langchainrb 0.5.0 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: b0a2fe8026e861c9d97465bce7da08a0b077492d6f7cf8fb42c45dbfdfe6749f
- data.tar.gz: c04099c44a847bd9c05e8594859f92ca1f54d338c463ce59a375c2cb9731b1ad
+ metadata.gz: d36de4206b792714ba9b6773c03272e9638b14caf7140e0bc00c3e767aa5fdef
+ data.tar.gz: 819fab9de55a34e4e6dc865febc19bb9979df55fa8fc6a753774cf1961c40103
  SHA512:
- metadata.gz: dec375b2b7cae377cf31f3f8ed0a6ac9d79215c945e7c0da78ed1fbad3c502ecfcc5ce5318c55a9a634db88fea1f5fbbeed7a0f7dc6ab8096c909e0a3ff02154
- data.tar.gz: 558a0f6ddf90ad044f9e2cc7c6ca678958472748d2e430a3cbb4308290898b9b22a204a94a5b5943f06137f7da5238d038a65b8c79d96f8a3705499d95cfb597
+ metadata.gz: 6e180b41bbca96bd5523c276923f223bbebe470314086c6a909df440890793bcc70dbd66ecf59bf5d0fd52426650cc5d2684c56cc8fc643209cc1679527cbef4
+ data.tar.gz: af5db76c2b22b5c7bdc1170de437921e8464a16566f46a5cad465d69e6da47c97a82f7331a5ea5747840e58acc71463aa8456b03e9bc8851efda7b734e5d23cc
data/CHANGELOG.md CHANGED
@@ -1,5 +1,13 @@
  ## [Unreleased]
 
+ ## [0.5.2] - 2023-06-07
+ - 🗣️ LLMs
+ - Auto-calculate the max_tokens: setting to be passed on to OpenAI
+
+ ## [0.5.1] - 2023-06-06
+ - 🛠️ Tools
+ - Modified Tool usage. Agents now accept Tool instances instead of Tool strings.
+
  ## [0.5.0] - 2023-06-05
  - [BREAKING] LLMs are now passed as objects to Vectorsearch classes instead of `llm: :name, llm_api_key:` previously
  - 📋 Prompts
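Taken together, the 0.5.1 and 0.5.2 entries change the calling convention for agents: tools are now passed as constructed instances, and `max_tokens:` no longer needs to be budgeted by hand. A minimal sketch of the new call pattern, assuming API keys are supplied via environment variables (the question text is illustrative):

```ruby
require "langchain"

# 0.5.1: tools are passed as instances, not strings
search = Langchain::Tool::SerpApi.new(api_key: ENV["SERPAPI_API_KEY"])
calculator = Langchain::Tool::Calculator.new

agent = Langchain::Agent::ChainOfThoughtAgent.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  tools: [search, calculator]
)

# 0.5.2: max_tokens is derived from the model's token limit automatically,
# so no max_tokens: argument appears anywhere in this flow
agent.run(question: "How many meters are in a mile?")
```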
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
  PATH
    remote: .
    specs:
-     langchainrb (0.5.0)
+     langchainrb (0.5.2)
        colorize (~> 0.8.1)
        tiktoken_ruby (~> 0.0.5)
 
data/README.md CHANGED
@@ -256,7 +256,15 @@ Agents are semi-autonomous bots that can respond to user questions and use avail
  Add `gem "ruby-openai"`, `gem "eqn"`, and `gem "google_search_results"` to your Gemfile
 
  ```ruby
- agent = Langchain::Agent::ChainOfThoughtAgent.new(llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]), tools: ['search', 'calculator'])
+ search_tool = Langchain::Tool::SerpApi.new(api_key: ENV["SERPAPI_API_KEY"])
+ calculator = Langchain::Tool::Calculator.new
+
+ openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
+
+ agent = Langchain::Agent::ChainOfThoughtAgent.new(
+   llm: openai,
+   tools: [search_tool, calculator]
+ )
 
  agent.tools
  # => ["search", "calculator"]
@@ -271,11 +279,12 @@ agent.run(question: "How many full soccer fields would be needed to cover the di
  Add `gem "sequel"` to your Gemfile
 
  ```ruby
- agent = Langchain::Agent::SQLQueryAgent.new(llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]), db_connection_string: "postgres://user:password@localhost:5432/db_name")
+ database = Langchain::Tool::Database.new(connection_string: "postgres://user:password@localhost:5432/db_name")
 
+ agent = Langchain::Agent::SQLQueryAgent.new(llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]), db: database)
  ```
  ```ruby
- agent.ask(question: "How many users have a name with length greater than 5 in the users table?")
+ agent.run(question: "How many users have a name with length greater than 5 in the users table?")
  #=> "14 users have a name with length greater than 5 in the users table."
  ```
 
@@ -1,10 +1,10 @@
  require "langchain"
 
  # Create a prompt with a few shot examples
- prompt = Prompt::FewShotPromptTemplate.new(
+ prompt = Langchain::Prompt::FewShotPromptTemplate.new(
    prefix: "Write antonyms for the following words.",
    suffix: "Input: {adjective}\nOutput:",
-   example_prompt: Prompt::PromptTemplate.new(
+   example_prompt: Langchain::Prompt::PromptTemplate.new(
      input_variables: ["input", "output"],
      template: "Input: {input}\nOutput: {output}"
    ),
@@ -32,5 +32,5 @@ prompt.format(adjective: "good")
  prompt.save(file_path: "spec/fixtures/prompt/few_shot_prompt_template.json")
 
  # Loading a new prompt template using a JSON file
- prompt = Prompt.load_from_path(file_path: "spec/fixtures/prompt/few_shot_prompt_template.json")
+ prompt = Langchain::Prompt.load_from_path(file_path: "spec/fixtures/prompt/few_shot_prompt_template.json")
  prompt.prefix # "Write antonyms for the following words."
@@ -1,15 +1,15 @@
  require "langchain"
 
  # Create a prompt with one input variable
- prompt = Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke.", input_variables: ["adjective"])
+ prompt = Langchain::Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke.", input_variables: ["adjective"])
  prompt.format(adjective: "funny") # "Tell me a funny joke."
 
  # Create a prompt with multiple input variables
- prompt = Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke about {content}.", input_variables: ["adjective", "content"])
+ prompt = Langchain::Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke about {content}.", input_variables: ["adjective", "content"])
  prompt.format(adjective: "funny", content: "chickens") # "Tell me a funny joke about chickens."
 
  # Creating a PromptTemplate using just a prompt and no input_variables
- prompt = Prompt::PromptTemplate.from_template("Tell me a {adjective} joke about {content}.")
+ prompt = Langchain::Prompt::PromptTemplate.from_template("Tell me a {adjective} joke about {content}.")
  prompt.input_variables # ["adjective", "content"]
  prompt.format(adjective: "funny", content: "chickens") # "Tell me a funny joke about chickens."
 
@@ -17,5 +17,9 @@ prompt.format(adjective: "funny", content: "chickens") # "Tell me a funny joke a
  prompt.save(file_path: "spec/fixtures/prompt/prompt_template.json")
 
  # Loading a new prompt template using a JSON file
- prompt = Prompt.load_from_path(file_path: "spec/fixtures/prompt/prompt_template.json")
+ prompt = Langchain::Prompt.load_from_path(file_path: "spec/fixtures/prompt/prompt_template.json")
+ prompt.input_variables # ["adjective", "content"]
+
+ # Loading a new prompt template using a YAML file
+ prompt = Langchain::Prompt.load_from_path(file_path: "spec/fixtures/prompt/prompt_template.yaml")
  prompt.input_variables # ["adjective", "content"]
@@ -4,7 +4,7 @@ require "langchain"
  # or add `gem "chroma-db", "~> 0.3.0"` to your Gemfile
 
  # Instantiate the Chroma client
- chroma = Vectorsearch::Chroma.new(
+ chroma = Langchain::Vectorsearch::Chroma.new(
    url: ENV["CHROMA_URL"],
    index_name: "documents",
    llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
@@ -4,7 +4,7 @@ require "langchain"
  # or add `gem "pinecone"` to your Gemfile
 
  # Instantiate the Pinecone client
- pinecone = Vectorsearch::Pinecone.new(
+ pinecone = Langchain::Vectorsearch::Pinecone.new(
    environment: ENV["PINECONE_ENVIRONMENT"],
    api_key: ENV["PINECONE_API_KEY"],
    index_name: "recipes",
@@ -37,7 +37,7 @@ pinecone.ask(
  )
 
  # Generate an embedding and search by it
- openai = LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
+ openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
  embedding = openai.embed(text: "veggie")
 
  pinecone.similarity_search_by_vector(
@@ -4,7 +4,7 @@ require "langchain"
  # or add `gem "qdrant-ruby"` to your Gemfile
 
  # Instantiate the Qdrant client
- qdrant = Vectorsearch::Qdrant.new(
+ qdrant = Langchain::Vectorsearch::Qdrant.new(
    url: ENV["QDRANT_URL"],
    api_key: ENV["QDRANT_API_KEY"],
    index_name: "recipes",
@@ -4,7 +4,7 @@ require "langchain"
  # or add `gem "weaviate-ruby"` to your Gemfile
 
  # Instantiate the Weaviate client
- weaviate = Vectorsearch::Weaviate.new(
+ weaviate = Langchain::Vectorsearch::Weaviate.new(
    url: ENV["WEAVIATE_URL"],
    api_key: ENV["WEAVIATE_API_KEY"],
    index_name: "Recipes",
@@ -39,11 +39,8 @@ module Langchain::Agent
 
  loop do
    Langchain.logger.info("[#{self.class.name}]".red + ": Sending the prompt to the #{llm.class} LLM")
-   response = llm.complete(
-     prompt: prompt,
-     stop_sequences: ["Observation:"],
-     max_tokens: 500
-   )
+
+   response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
 
    # Append the response to the prompt
    prompt += response
@@ -55,10 +52,11 @@ module Langchain::Agent
    # Find the input to the action in the "Action Input: [action_input]" format
    action_input = response.match(/Action Input: "?(.*)"?/)&.send(:[], -1)
 
-   # Retrieve the Tool::[ToolName] class and call `execute` with action_input as the input
-   tool = Langchain::Tool.const_get(Langchain::Tool::Base::TOOLS[action.strip])
-   Langchain.logger.info("[#{self.class.name}]".red + ": Invoking \"#{tool}\" Tool with \"#{action_input}\"")
+   # Find the Tool and call `execute` with action_input as the input
+   tool = tools.find { |tool| tool.tool_name == action.strip }
+   Langchain.logger.info("[#{self.class.name}]".red + ": Invoking \"#{tool.class}\" Tool with \"#{action_input}\"")
 
+   # Call `execute` with action_input as the input
    result = tool.execute(input: action_input)
 
    # Append the Observation to the prompt
@@ -81,12 +79,16 @@ module Langchain::Agent
  # @param tools [Array] Tools to use
  # @return [String] Prompt
  def create_prompt(question:, tools:)
+   tool_list = tools.map(&:tool_name)
+
    prompt_template.format(
      date: Date.today.strftime("%B %d, %Y"),
      question: question,
-     tool_names: "[#{tools.join(", ")}]",
+     tool_names: "[#{tool_list.join(", ")}]",
      tools: tools.map do |tool|
-       "#{tool}: #{Langchain::Tool.const_get(Langchain::Tool::Base::TOOLS[tool]).const_get(:DESCRIPTION)}"
+       tool_name = tool.tool_name
+       tool_description = tool.class.const_get(:DESCRIPTION)
+       "#{tool_name}: #{tool_description}"
      end.join("\n")
    )
  end
@@ -4,26 +4,30 @@ module Langchain::Agent
  class SQLQueryAgent < Base
    attr_reader :llm, :db, :schema
 
+   #
    # Initializes the Agent
    #
    # @param llm [Object] The LLM client to use
-   # @param db_connection_string [String] Database connection info
-   def initialize(llm:, db_connection_string:)
+   # @param db [Object] Database connection info
+   #
+   def initialize(llm:, db:)
      @llm = llm
-     @db = Langchain::Tool::Database.new(db_connection_string)
+     @db = db
      @schema = @db.schema
    end
 
+   #
    # Ask a question and get an answer
    #
    # @param question [String] Question to ask the LLM/Database
    # @return [String] Answer to the question
-   def ask(question:)
+   #
+   def run(question:)
      prompt = create_prompt_for_sql(question: question)
 
      # Get the SQL string to execute
      Langchain.logger.info("[#{self.class.name}]".red + ": Passing the initial prompt to the #{llm.class} LLM")
-     sql_string = llm.complete(prompt: prompt, max_tokens: 500)
+     sql_string = llm.complete(prompt: prompt)
 
      # Execute the SQL string and collect the results
      Langchain.logger.info("[#{self.class.name}]".red + ": Passing the SQL to the Database: #{sql_string}")
@@ -32,7 +36,7 @@ module Langchain::Agent
      # Pass the results and get the LLM to synthesize the answer to the question
      Langchain.logger.info("[#{self.class.name}]".red + ": Passing the synthesize prompt to the #{llm.class} LLM with results: #{results}")
      prompt2 = create_prompt_for_answer(question: question, sql_query: sql_string, results: results)
-     llm.complete(prompt: prompt2, max_tokens: 500)
+     llm.complete(prompt: prompt2)
    end
 
    private
@@ -35,7 +35,7 @@ module Langchain::LLM
  def embed(text:, **params)
    parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
 
-   Langchain::Utils::TokenLengthValidator.validate!(text, parameters[:model])
+   Langchain::Utils::TokenLengthValidator.validate_max_tokens!(text, parameters[:model])
 
    response = client.embeddings(parameters: parameters.merge(params))
    response.dig("data").first.dig("embedding")
@@ -50,9 +50,8 @@ module Langchain::LLM
  def complete(prompt:, **params)
    parameters = compose_parameters DEFAULTS[:completion_model_name], params
 
-   Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
-
    parameters[:prompt] = prompt
+   parameters[:max_tokens] = Langchain::Utils::TokenLengthValidator.validate_max_tokens!(prompt, parameters[:model])
 
    response = client.completions(parameters: parameters)
    response.dig("choices", 0, "text")
@@ -67,9 +66,8 @@ module Langchain::LLM
  def chat(prompt:, **params)
    parameters = compose_parameters DEFAULTS[:chat_completion_model_name], params
 
-   Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
-
    parameters[:messages] = [{role: "user", content: prompt}]
+   parameters[:max_tokens] = Langchain::Utils::TokenLengthValidator.validate_max_tokens!(prompt, parameters[:model])
 
    response = client.chat(parameters: parameters)
    response.dig("choices", 0, "message", "content")
@@ -87,12 +85,7 @@ module Langchain::LLM
    )
    prompt = prompt_template.format(text: text)
 
-   complete(
-     prompt: prompt,
-     temperature: DEFAULTS[:temperature],
-     # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-     max_tokens: 2048
-   )
+   complete(prompt: prompt, temperature: DEFAULTS[:temperature])
  end
 
  private
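The net effect of these `complete`/`chat`/`summarize` hunks for callers: the hard-coded `max_tokens: 500` and `max_tokens: 2048` budgets are gone, and the validator fills in the remaining context window instead. A sketch of the resulting call-site behavior (prompt text is illustrative):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

# Pre-0.5.2 callers budgeted tokens themselves:
#   llm.complete(prompt: prompt, max_tokens: 500)
# Now max_tokens is set to (model token limit - prompt token length):
llm.complete(prompt: "Tell me a funny joke.")

# Note: as the hunk reads, parameters[:max_tokens] is assigned after the
# caller's params are merged in, so an explicit max_tokens: passed here
# would be overwritten by the auto-calculated value.
llm.chat(prompt: "Tell me a funny joke.", max_tokens: 10)
```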
@@ -6,47 +6,59 @@ module Langchain::Tool
 
  # How to add additional Tools?
  # 1. Create a new file in lib/tool/your_tool_name.rb
- # 2. Add your tool to the TOOLS hash below
- #    "your_tool_name" => "Tool::YourToolName"
- # 3. Implement `self.execute(input:)` method in your tool class
- # 4. Add your tool to the README.md
+ # 2. Create a class in the file that inherits from Langchain::Tool::Base
+ # 3. Add `NAME=` and `DESCRIPTION=` constants in your Tool class
+ # 4. Implement `execute(input:)` method in your tool class
+ # 5. Add your tool to the README.md
 
- TOOLS = {
-   "calculator" => "Langchain::Tool::Calculator",
-   "search" => "Langchain::Tool::SerpApi",
-   "wikipedia" => "Langchain::Tool::Wikipedia",
-   "database" => "Langchain::Tool::Database"
- }
+ #
+ # Returns the NAME constant of the tool
+ #
+ # @return [String] tool name
+ #
+ def tool_name
+   self.class.const_get(:NAME)
+ end
 
+ #
+ # Sets the DESCRIPTION constant of the tool
+ #
+ # @param value [String] tool description
+ #
  def self.description(value)
    const_set(:DESCRIPTION, value.tr("\n", " ").strip)
  end
 
+ #
  # Instantiates and executes the tool and returns the answer
+ #
  # @param input [String] input to the tool
  # @return [String] answer
+ #
  def self.execute(input:)
    new.execute(input: input)
  end
 
+ #
  # Executes the tool and returns the answer
+ #
  # @param input [String] input to the tool
  # @return [String] answer
+ #
  def execute(input:)
    raise NotImplementedError, "Your tool must implement the `#execute(input:)` method that returns a string"
  end
 
  #
- # Validates the list of strings (tools) are all supported or raises an error
- # @param tools [Array<String>] list of tools to be used
+ # Validates the list of tools or raises an error
+ # @param tools [Array<Langchain::Tool>] list of tools to be used
  #
  # @raise [ArgumentError] If any of the tools are not supported
  #
  def self.validate_tools!(tools:)
-   unrecognized_tools = tools - Langchain::Tool::Base::TOOLS.keys
-
-   if unrecognized_tools.any?
-     raise ArgumentError, "Unrecognized Tools: #{unrecognized_tools}"
+   # Check if the tool count is equal to the unique tool count
+   if tools.count != tools.map(&:tool_name).uniq.count
+     raise ArgumentError, "Either tools are not unique or are conflicting with each other"
    end
  end
  end
  end
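The rewritten how-to comment at the top of this hunk can be followed mechanically. A hypothetical tool written against the new base class — the class, NAME, and behavior below are illustrative, not part of the gem:

```ruby
module Langchain::Tool
  class RandomNumber < Base
    NAME = "random_number"

    description <<~DESC
      Useful for generating a random integer.

      The input to this tool should be the upper bound, e.g. "100".
    DESC

    # @param input [String] the upper bound for the random number
    # @return [String] the generated number, as a string for the agent loop
    def execute(input:)
      rand(1..input.to_i).to_s
    end
  end
end
```

`tool_name` then returns `"random_number"` via the NAME constant, which is what `ChainOfThoughtAgent` matches against the LLM's `Action:` output.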
@@ -8,9 +8,10 @@ module Langchain::Tool
  # Gem requirements:
  #   gem "eqn", "~> 1.6.5"
  #   gem "google_search_results", "~> 2.0.0"
- # ENV requirements: ENV["SERPAPI_API_KEY"]
  #
 
+ NAME = "calculator"
+
  description <<~DESC
    Useful for getting the result of a math expression.
 
@@ -33,8 +34,12 @@ module Langchain::Tool
  rescue Eqn::ParseError, Eqn::NoVariableValueError
    # Sometimes the input is not a pure math expression, e.g: "12F in Celsius"
    # We can use the google answer box to evaluate this expression
-   hash_results = Langchain::Tool::SerpApi.execute_search(input: input)
-   hash_results.dig(:answer_box, :to)
+   # TODO: Figure out a better way to evaluate these language expressions.
+   hash_results = Langchain::Tool::SerpApi
+     .new(api_key: ENV["SERPAPI_API_KEY"])
+     .execute_search(input: input)
+   hash_results.dig(:answer_box, :to) ||
+     hash_results.dig(:answer_box, :result)
  end
  end
  end
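A quick sketch of how the reworked fallback might behave (assuming `SERPAPI_API_KEY` is set; results depend on the eqn gem and Google's answer box):

```ruby
calculator = Langchain::Tool::Calculator.new
calculator.execute(input: "2 ** 10") # evaluated locally by the eqn gem

# A non-arithmetic input such as "12F in Celsius" raises Eqn::ParseError
# internally and falls through to the SerpApi answer box, which is now
# checked for both the :to and :result keys.
calculator.execute(input: "12F in Celsius")
```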
@@ -6,40 +6,55 @@ module Langchain::Tool
  # Gem requirements: gem "sequel", "~> 5.68.0"
  #
 
+ NAME = "database"
+
  description <<~DESC
    Useful for getting the result of a database query.
 
    The input to this tool should be valid SQL.
  DESC
 
+ attr_reader :db
+
+ #
  # Establish a database connection
- # @param db_connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
- def initialize(db_connection_string)
+ #
+ # @param connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+ # @return [Database] Database object
+ #
+ def initialize(connection_string:)
    depends_on "sequel"
    require "sequel"
    require "sequel/extensions/schema_dumper"
 
-   raise StandardError, "db_connection_string parameter cannot be blank" if db_connection_string.empty?
+   raise StandardError, "connection_string parameter cannot be blank" if connection_string.empty?
 
-   @db = Sequel.connect(db_connection_string)
+   @db = Sequel.connect(connection_string)
    @db.extension :schema_dumper
  end
 
+ #
+ # Returns the database schema
+ #
+ # @return [String] schema
+ #
  def schema
    Langchain.logger.info("[#{self.class.name}]".light_blue + ": Dumping schema")
-   @db.dump_schema_migration(same_db: true, indexes: false) unless @db.adapter_scheme == :mock
+   db.dump_schema_migration(same_db: true, indexes: false) unless db.adapter_scheme == :mock
  end
 
+ #
  # Evaluates a sql expression
+ #
  # @param input [String] sql expression
  # @return [Array] results
+ #
  def execute(input:)
    Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
-   begin
-     @db[input].to_a
-   rescue Sequel::DatabaseError => e
-     Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
-   end
+
+   db[input].to_a
+ rescue Sequel::DatabaseError => e
+   Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
  end
  end
  end
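The keyword-argument constructor now matches the README change earlier in this diff. A usage sketch (the connection string is a placeholder):

```ruby
database = Langchain::Tool::Database.new(
  connection_string: "postgres://user:password@localhost:5432/db_name"
)

database.schema # dumped via Sequel's schema_dumper extension
database.execute(input: "SELECT COUNT(*) FROM users;")

# Note the restructured rescue: Sequel::DatabaseError is now caught at the
# method level and logged, so a bad query no longer needs the inner
# begin/end block and does not abort the agent loop.
```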
@@ -7,7 +7,7 @@ module Langchain::Tool
  #
  # Gem requirements: gem "safe_ruby", "~> 1.0.4"
  #
-
+ NAME = "ruby_code_interpreter"
  description <<~DESC
    A Ruby code interpreter. Use this to execute ruby expressions. Input should be a valid ruby expression. If you want to see the output of the tool, make sure to return a value.
  DESC
@@ -6,8 +6,13 @@ module Langchain::Tool
  # Wrapper around SerpAPI
  #
  # Gem requirements: gem "google_search_results", "~> 2.0.0"
- # ENV requirements: ENV["SERPAPI_API_KEY"] # https://serpapi.com/manage-api-key)
  #
+ # Usage:
+ #   search = Langchain::Tool::SerpApi.new(api_key: "YOUR_API_KEY")
+ #   search.execute(input: "What is the capital of France?")
+ #
+
+ NAME = "search"
 
  description <<~DESC
    A wrapper around Google Search.
@@ -18,39 +23,57 @@ module Langchain::Tool
    Input should be a search query.
  DESC
 
- def initialize
+ attr_reader :api_key
+
+ #
+ # Initializes the SerpAPI tool
+ #
+ # @param api_key [String] SerpAPI API key
+ # @return [Langchain::Tool::SerpApi] SerpAPI tool
+ #
+ def initialize(api_key:)
    depends_on "google_search_results"
    require "google_search_results"
+   @api_key = api_key
  end
 
+ #
  # Executes Google Search and returns hash_results JSON
+ #
  # @param input [String] search query
  # @return [Hash] hash_results JSON
-
+ #
  def self.execute_search(input:)
    new.execute_search(input: input)
  end
 
- # Executes Google Search and returns hash_results JSON
+ #
+ # Executes Google Search and returns the result
+ #
  # @param input [String] search query
  # @return [String] Answer
- # TODO: Glance at all of the fields that langchain Python looks through: https://github.com/hwchase17/langchain/blob/v0.0.166/langchain/utilities/serpapi.py#L128-L156
- # We may need to do the same thing here.
+ #
  def execute(input:)
    Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
 
    hash_results = execute_search(input: input)
 
+   # TODO: Glance at all of the fields that langchain Python looks through: https://github.com/hwchase17/langchain/blob/v0.0.166/langchain/utilities/serpapi.py#L128-L156
+   # We may need to do the same thing here.
    hash_results.dig(:answer_box, :answer) ||
      hash_results.dig(:answer_box, :snippet) ||
      hash_results.dig(:organic_results, 0, :snippet)
  end
 
+ #
+ # Executes Google Search and returns hash_results JSON
+ #
+ # @param input [String] search query
+ # @return [Hash] hash_results JSON
+ #
  def execute_search(input:)
-   GoogleSearch.new(
-     q: input,
-     serp_api_key: ENV["SERPAPI_API_KEY"]
-   )
+   GoogleSearch
+     .new(q: input, serp_api_key: api_key)
      .get_hash
  end
  end
@@ -7,7 +7,7 @@ module Langchain::Tool
  #
  # Gem requirements: gem "wikipedia-client", "~> 1.17.0"
  #
-
+ NAME = "wikipedia"
  description <<~DESC
    A wrapper around Wikipedia.
 
@@ -34,23 +34,50 @@ module Langchain
    "ada" => 2049
  }.freeze
 
+ # GOOGLE_PALM_TOKEN_LIMITS = {
+ #   "chat-bison-001" => {
+ #     "inputTokenLimit" => 4096,
+ #     "outputTokenLimit" => 1024
+ #   },
+ #   "text-bison-001" => {
+ #     "inputTokenLimit" => 8196,
+ #     "outputTokenLimit" => 1024
+ #   },
+ #   "embedding-gecko-001" => {
+ #     "inputTokenLimit" => 1024
+ #   }
+ # }.freeze
+
  #
- # Validate the length of the text passed in to OpenAI's API
+ # Calculate the `max_tokens:` parameter to be set: the model's token limit minus the token length of the text
  #
  # @param text [String] The text to validate
  # @param model_name [String] The model name to validate against
- # @return [Boolean] Whether the text is valid or not
+ # @return [Integer] The calculated max_tokens value
  # @raise [TokenLimitExceeded] If the text is too long
  #
- def self.validate!(text, model_name)
-   encoder = Tiktoken.encoding_for_model(model_name)
-   token_length = encoder.encode(text).length
+ def self.validate_max_tokens!(text, model_name)
+   text_token_length = token_length(text, model_name)
+   max_tokens = TOKEN_LIMITS[model_name] - text_token_length
 
-   if token_length > TOKEN_LIMITS[model_name]
-     raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{token_length} tokens long."
+   # Raise an error even if the whole prompt is exactly equal to the model's token limit (max_tokens == 0), since no response could be returned
+   if max_tokens <= 0
+     raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
    end
 
-   true
+   max_tokens
+ end
+
+ #
+ # Calculate token length for a given text and model name
+ #
+ # @param text [String] The text to validate
+ # @param model_name [String] The model name to validate against
+ # @return [Integer] The token length of the text
+ #
+ def self.token_length(text, model_name)
+   encoder = Tiktoken.encoding_for_model(model_name)
+   encoder.encode(text).length
  end
  end
  end
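A sketch of the renamed validator in isolation; the token counts in the comments are illustrative:

```ruby
prompt = "Write antonyms for the following words."

max_tokens = Langchain::Utils::TokenLengthValidator.validate_max_tokens!(
  prompt, "text-davinci-003"
)
# => the model's TOKEN_LIMITS entry minus the prompt's token length,
#    e.g. 4097 - 8 = 4089

# A prompt at or over the model's limit raises TokenLimitExceeded, since
# max_tokens <= 0 leaves no room for a response.
```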
@@ -1,5 +1,5 @@
  # frozen_string_literal: true
 
  module Langchain
-   VERSION = "0.5.0"
+   VERSION = "0.5.2"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: langchainrb
  version: !ruby/object:Gem::Version
-   version: 0.5.0
+   version: 0.5.2
  platform: ruby
  authors:
  - Andrei Bondarev
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2023-06-05 00:00:00.000000000 Z
+ date: 2023-06-08 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: tiktoken_ruby