langchainrb 0.5.0 → 0.5.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: b0a2fe8026e861c9d97465bce7da08a0b077492d6f7cf8fb42c45dbfdfe6749f
-  data.tar.gz: c04099c44a847bd9c05e8594859f92ca1f54d338c463ce59a375c2cb9731b1ad
+  metadata.gz: d36de4206b792714ba9b6773c03272e9638b14caf7140e0bc00c3e767aa5fdef
+  data.tar.gz: 819fab9de55a34e4e6dc865febc19bb9979df55fa8fc6a753774cf1961c40103
 SHA512:
-  metadata.gz: dec375b2b7cae377cf31f3f8ed0a6ac9d79215c945e7c0da78ed1fbad3c502ecfcc5ce5318c55a9a634db88fea1f5fbbeed7a0f7dc6ab8096c909e0a3ff02154
-  data.tar.gz: 558a0f6ddf90ad044f9e2cc7c6ca678958472748d2e430a3cbb4308290898b9b22a204a94a5b5943f06137f7da5238d038a65b8c79d96f8a3705499d95cfb597
+  metadata.gz: 6e180b41bbca96bd5523c276923f223bbebe470314086c6a909df440890793bcc70dbd66ecf59bf5d0fd52426650cc5d2684c56cc8fc643209cc1679527cbef4
+  data.tar.gz: af5db76c2b22b5c7bdc1170de437921e8464a16566f46a5cad465d69e6da47c97a82f7331a5ea5747840e58acc71463aa8456b03e9bc8851efda7b734e5d23cc
data/CHANGELOG.md CHANGED
@@ -1,5 +1,13 @@
 ## [Unreleased]
 
+## [0.5.2] - 2023-06-07
+- 🗣️ LLMs
+  - Auto-calculate the `max_tokens:` setting passed on to OpenAI
+
+## [0.5.1] - 2023-06-06
+- 🛠️ Tools
+  - Modified Tool usage. Agents now accept Tool instances instead of Tool name strings.
+
 ## [0.5.0] - 2023-06-05
 - [BREAKING] LLMs are now passed as objects to Vectorsearch classes instead of the previous `llm: :name, llm_api_key:` style
 - 📋 Prompts
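The 0.5.1 entry is the breaking change for most users. A minimal before/after sketch of the new call shape, mirroring the README diff below:

```ruby
# Before (0.5.0): tools were referenced by name strings
agent = Langchain::Agent::ChainOfThoughtAgent.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  tools: ["search", "calculator"]
)

# After (0.5.1+): tools are instantiated and passed as objects,
# so per-tool configuration (such as the SerpApi key) lives on the instance
agent = Langchain::Agent::ChainOfThoughtAgent.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  tools: [
    Langchain::Tool::SerpApi.new(api_key: ENV["SERPAPI_API_KEY"]),
    Langchain::Tool::Calculator.new
  ]
)
```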
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
 PATH
   remote: .
   specs:
-    langchainrb (0.5.0)
+    langchainrb (0.5.2)
       colorize (~> 0.8.1)
       tiktoken_ruby (~> 0.0.5)
 
data/README.md CHANGED
@@ -256,7 +256,15 @@ Agents are semi-autonomous bots that can respond to user questions and use avail
 Add `gem "ruby-openai"`, `gem "eqn"`, and `gem "google_search_results"` to your Gemfile
 
 ```ruby
-agent = Langchain::Agent::ChainOfThoughtAgent.new(llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]), tools: ['search', 'calculator'])
+search_tool = Langchain::Tool::SerpApi.new(api_key: ENV["SERPAPI_API_KEY"])
+calculator = Langchain::Tool::Calculator.new
+
+openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
+
+agent = Langchain::Agent::ChainOfThoughtAgent.new(
+  llm: openai,
+  tools: [search_tool, calculator]
+)
 
 agent.tools
 # => ["search", "calculator"]
@@ -271,11 +279,12 @@ agent.run(question: "How many full soccer fields would be needed to cover the di
 Add `gem "sequel"` to your Gemfile
 
 ```ruby
-agent = Langchain::Agent::SQLQueryAgent.new(llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]), db_connection_string: "postgres://user:password@localhost:5432/db_name")
+database = Langchain::Tool::Database.new(connection_string: "postgres://user:password@localhost:5432/db_name")
 
+agent = Langchain::Agent::SQLQueryAgent.new(llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]), db: database)
 ```
 ```ruby
-agent.ask(question: "How many users have a name with length greater than 5 in the users table?")
+agent.run(question: "How many users have a name with length greater than 5 in the users table?")
 #=> "14 users have a name with length greater than 5 in the users table."
 ```
 
@@ -1,10 +1,10 @@
 require "langchain"
 
 # Create a prompt with a few shot examples
-prompt = Prompt::FewShotPromptTemplate.new(
+prompt = Langchain::Prompt::FewShotPromptTemplate.new(
   prefix: "Write antonyms for the following words.",
   suffix: "Input: {adjective}\nOutput:",
-  example_prompt: Prompt::PromptTemplate.new(
+  example_prompt: Langchain::Prompt::PromptTemplate.new(
     input_variables: ["input", "output"],
     template: "Input: {input}\nOutput: {output}"
   ),
@@ -32,5 +32,5 @@ prompt.format(adjective: "good")
 prompt.save(file_path: "spec/fixtures/prompt/few_shot_prompt_template.json")
 
 # Loading a new prompt template using a JSON file
-prompt = Prompt.load_from_path(file_path: "spec/fixtures/prompt/few_shot_prompt_template.json")
+prompt = Langchain::Prompt.load_from_path(file_path: "spec/fixtures/prompt/few_shot_prompt_template.json")
 prompt.prefix # "Write antonyms for the following words."
@@ -1,15 +1,15 @@
 require "langchain"
 
 # Create a prompt with one input variable
-prompt = Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke.", input_variables: ["adjective"])
+prompt = Langchain::Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke.", input_variables: ["adjective"])
 prompt.format(adjective: "funny") # "Tell me a funny joke."
 
 # Create a prompt with multiple input variables
-prompt = Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke about {content}.", input_variables: ["adjective", "content"])
+prompt = Langchain::Prompt::PromptTemplate.new(template: "Tell me a {adjective} joke about {content}.", input_variables: ["adjective", "content"])
 prompt.format(adjective: "funny", content: "chickens") # "Tell me a funny joke about chickens."
 
 # Creating a PromptTemplate using just a prompt and no input_variables
-prompt = Prompt::PromptTemplate.from_template("Tell me a {adjective} joke about {content}.")
+prompt = Langchain::Prompt::PromptTemplate.from_template("Tell me a {adjective} joke about {content}.")
 prompt.input_variables # ["adjective", "content"]
 prompt.format(adjective: "funny", content: "chickens") # "Tell me a funny joke about chickens."
 
@@ -17,5 +17,9 @@ prompt.format(adjective: "funny", content: "chickens") # "Tell me a funny joke a
 prompt.save(file_path: "spec/fixtures/prompt/prompt_template.json")
 
 # Loading a new prompt template using a JSON file
-prompt = Prompt.load_from_path(file_path: "spec/fixtures/prompt/prompt_template.json")
+prompt = Langchain::Prompt.load_from_path(file_path: "spec/fixtures/prompt/prompt_template.json")
+prompt.input_variables # ["adjective", "content"]
+
+# Loading a new prompt template using a YAML file
+prompt = Langchain::Prompt.load_from_path(file_path: "spec/fixtures/prompt/prompt_template.yaml")
 prompt.input_variables # ["adjective", "content"]
@@ -4,7 +4,7 @@ require "langchain"
 # or add `gem "chroma-db", "~> 0.3.0"` to your Gemfile
 
 # Instantiate the Chroma client
-chroma = Vectorsearch::Chroma.new(
+chroma = Langchain::Vectorsearch::Chroma.new(
   url: ENV["CHROMA_URL"],
   index_name: "documents",
   llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
@@ -4,7 +4,7 @@ require "langchain"
 # or add `gem "pinecone"` to your Gemfile
 
 # Instantiate the Pinecone client
-pinecone = Vectorsearch::Pinecone.new(
+pinecone = Langchain::Vectorsearch::Pinecone.new(
   environment: ENV["PINECONE_ENVIRONMENT"],
   api_key: ENV["PINECONE_API_KEY"],
   index_name: "recipes",
@@ -37,7 +37,7 @@ pinecone.ask(
 )
 
 # Generate an embedding and search by it
-openai = LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
+openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
 embedding = openai.embed(text: "veggie")
 
 pinecone.similarity_search_by_vector(
@@ -4,7 +4,7 @@ require "langchain"
 # or add `gem "qdrant-ruby"` to your Gemfile
 
 # Instantiate the Qdrant client
-qdrant = Vectorsearch::Qdrant.new(
+qdrant = Langchain::Vectorsearch::Qdrant.new(
   url: ENV["QDRANT_URL"],
   api_key: ENV["QDRANT_API_KEY"],
   index_name: "recipes",
@@ -4,7 +4,7 @@ require "langchain"
 # or add `gem "weaviate-ruby"` to your Gemfile
 
 # Instantiate the Weaviate client
-weaviate = Vectorsearch::Weaviate.new(
+weaviate = Langchain::Vectorsearch::Weaviate.new(
   url: ENV["WEAVIATE_URL"],
   api_key: ENV["WEAVIATE_API_KEY"],
   index_name: "Recipes",
@@ -39,11 +39,8 @@ module Langchain::Agent
 
       loop do
        Langchain.logger.info("[#{self.class.name}]".red + ": Sending the prompt to the #{llm.class} LLM")
-        response = llm.complete(
-          prompt: prompt,
-          stop_sequences: ["Observation:"],
-          max_tokens: 500
-        )
+
+        response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
 
         # Append the response to the prompt
         prompt += response
@@ -55,10 +52,11 @@ module Langchain::Agent
         # Find the input to the action in the "Action Input: [action_input]" format
         action_input = response.match(/Action Input: "?(.*)"?/)&.send(:[], -1)
 
-        # Retrieve the Tool::[ToolName] class and call `execute` with action_input as the input
-        tool = Langchain::Tool.const_get(Langchain::Tool::Base::TOOLS[action.strip])
-        Langchain.logger.info("[#{self.class.name}]".red + ": Invoking \"#{tool}\" Tool with \"#{action_input}\"")
+        # Find the Tool instance whose name matches the action
+        tool = tools.find { |tool| tool.tool_name == action.strip }
+        Langchain.logger.info("[#{self.class.name}]".red + ": Invoking \"#{tool.class}\" Tool with \"#{action_input}\"")
 
+        # Call `execute` with action_input as the input
         result = tool.execute(input: action_input)
 
         # Append the Observation to the prompt
@@ -81,12 +79,16 @@ module Langchain::Agent
     # @param tools [Array] Tools to use
     # @return [String] Prompt
     def create_prompt(question:, tools:)
+      tool_list = tools.map(&:tool_name)
+
       prompt_template.format(
         date: Date.today.strftime("%B %d, %Y"),
         question: question,
-        tool_names: "[#{tools.join(", ")}]",
+        tool_names: "[#{tool_list.join(", ")}]",
         tools: tools.map do |tool|
-          "#{tool}: #{Langchain::Tool.const_get(Langchain::Tool::Base::TOOLS[tool]).const_get(:DESCRIPTION)}"
+          tool_name = tool.tool_name
+          tool_description = tool.class.const_get(:DESCRIPTION)
+          "#{tool_name}: #{tool_description}"
         end.join("\n")
       )
     end
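For clarity, here is a small standalone sketch of what `create_prompt` now derives from Tool instances; `FakeSearch` is a hypothetical stand-in, not part of the gem:

```ruby
# Hypothetical stand-in that mimics the Tool interface used above
class FakeSearch
  NAME = "search"
  DESCRIPTION = "A wrapper around Google Search."

  def tool_name
    self.class.const_get(:NAME)
  end
end

tools = [FakeSearch.new]

# tool_names: rendered into the prompt template
"[#{tools.map(&:tool_name).join(", ")}]"
# => "[search]"

# tools: one "name: description" line per tool
tools.map { |tool| "#{tool.tool_name}: #{tool.class.const_get(:DESCRIPTION)}" }.join("\n")
# => "search: A wrapper around Google Search."
```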
@@ -4,26 +4,30 @@ module Langchain::Agent
   class SQLQueryAgent < Base
     attr_reader :llm, :db, :schema
 
+    #
     # Initializes the Agent
     #
     # @param llm [Object] The LLM client to use
-    # @param db_connection_string [String] Database connection info
-    def initialize(llm:, db_connection_string:)
+    # @param db [Object] Database connection info
+    #
+    def initialize(llm:, db:)
       @llm = llm
-      @db = Langchain::Tool::Database.new(db_connection_string)
+      @db = db
       @schema = @db.schema
     end
 
+    #
     # Ask a question and get an answer
     #
     # @param question [String] Question to ask the LLM/Database
     # @return [String] Answer to the question
-    def ask(question:)
+    #
+    def run(question:)
       prompt = create_prompt_for_sql(question: question)
 
       # Get the SQL string to execute
       Langchain.logger.info("[#{self.class.name}]".red + ": Passing the initial prompt to the #{llm.class} LLM")
-      sql_string = llm.complete(prompt: prompt, max_tokens: 500)
+      sql_string = llm.complete(prompt: prompt)
 
       # Execute the SQL string and collect the results
       Langchain.logger.info("[#{self.class.name}]".red + ": Passing the SQL to the Database: #{sql_string}")
@@ -32,7 +36,7 @@ module Langchain::Agent
 
       # Pass the results and get the LLM to synthesize the answer to the question
       Langchain.logger.info("[#{self.class.name}]".red + ": Passing the synthesize prompt to the #{llm.class} LLM with results: #{results}")
       prompt2 = create_prompt_for_answer(question: question, sql_query: sql_string, results: results)
-      llm.complete(prompt: prompt2, max_tokens: 500)
+      llm.complete(prompt: prompt2)
     end
 
     private
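Putting the renames together (matching the README diff above), the agent is now constructed and invoked roughly like this; connection details and keys are placeholders:

```ruby
database = Langchain::Tool::Database.new(
  connection_string: "postgres://user:password@localhost:5432/db_name"
)

agent = Langchain::Agent::SQLQueryAgent.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  db: database
)

# `ask(question:)` was renamed to `run(question:)`
agent.run(question: "How many users signed up this week?")
```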
@@ -35,7 +35,7 @@ module Langchain::LLM
     def embed(text:, **params)
       parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLengthValidator.validate!(text, parameters[:model])
+      Langchain::Utils::TokenLengthValidator.validate_max_tokens!(text, parameters[:model])
 
       response = client.embeddings(parameters: parameters.merge(params))
       response.dig("data").first.dig("embedding")
@@ -50,9 +50,8 @@ module Langchain::LLM
     def complete(prompt:, **params)
       parameters = compose_parameters DEFAULTS[:completion_model_name], params
 
-      Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
-
       parameters[:prompt] = prompt
+      parameters[:max_tokens] = Langchain::Utils::TokenLengthValidator.validate_max_tokens!(prompt, parameters[:model])
 
       response = client.completions(parameters: parameters)
       response.dig("choices", 0, "text")
@@ -67,9 +66,8 @@ module Langchain::LLM
     def chat(prompt:, **params)
       parameters = compose_parameters DEFAULTS[:chat_completion_model_name], params
 
-      Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
-
       parameters[:messages] = [{role: "user", content: prompt}]
+      parameters[:max_tokens] = Langchain::Utils::TokenLengthValidator.validate_max_tokens!(prompt, parameters[:model])
 
       response = client.chat(parameters: parameters)
       response.dig("choices", 0, "message", "content")
@@ -87,12 +85,7 @@ module Langchain::LLM
       )
       prompt = prompt_template.format(text: text)
 
-      complete(
-        prompt: prompt,
-        temperature: DEFAULTS[:temperature],
-        # Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
-        max_tokens: 2048
-      )
+      complete(prompt: prompt, temperature: DEFAULTS[:temperature])
     end
 
     private
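The net effect of these hunks is that callers no longer hard-code `max_tokens:`; `complete` and `chat` now fill it in from the model's context window. A minimal sketch of the new call shape, assuming a valid `OPENAI_API_KEY`:

```ruby
require "langchain"

openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

# No max_tokens: argument needed; it is auto-calculated as the model's
# context length minus the prompt's token count before the API call
openai.complete(prompt: "Write a haiku about version bumps.")
openai.chat(prompt: "Summarize the 0.5.2 release.")
```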
@@ -6,47 +6,59 @@ module Langchain::Tool
 
     # How to add additional Tools?
     # 1. Create a new file in lib/tool/your_tool_name.rb
-    # 2. Add your tool to the TOOLS hash below
-    #    "your_tool_name" => "Tool::YourToolName"
-    # 3. Implement `self.execute(input:)` method in your tool class
-    # 4. Add your tool to the README.md
+    # 2. Create a class in the file that inherits from Langchain::Tool::Base
+    # 3. Add `NAME=` and `DESCRIPTION=` constants in your Tool class
+    # 4. Implement `execute(input:)` method in your tool class
+    # 5. Add your tool to the README.md
 
-    TOOLS = {
-      "calculator" => "Langchain::Tool::Calculator",
-      "search" => "Langchain::Tool::SerpApi",
-      "wikipedia" => "Langchain::Tool::Wikipedia",
-      "database" => "Langchain::Tool::Database"
-    }
+    #
+    # Returns the NAME constant of the tool
+    #
+    # @return [String] tool name
+    #
+    def tool_name
+      self.class.const_get(:NAME)
+    end
 
+    #
+    # Sets the DESCRIPTION constant of the tool
+    #
+    # @param value [String] tool description
+    #
     def self.description(value)
       const_set(:DESCRIPTION, value.tr("\n", " ").strip)
     end
 
+    #
     # Instantiates and executes the tool and returns the answer
+    #
     # @param input [String] input to the tool
     # @return [String] answer
+    #
     def self.execute(input:)
       new.execute(input: input)
     end
 
+    #
     # Executes the tool and returns the answer
+    #
     # @param input [String] input to the tool
     # @return [String] answer
+    #
     def execute(input:)
       raise NotImplementedError, "Your tool must implement the `#execute(input:)` method that returns a string"
     end
 
     #
-    # Validates that the list of strings (tools) are all supported, or raises an error
-    # @param tools [Array<String>] list of tools to be used
+    # Validates the list of tools or raises an error
+    # @param tools [Array<Langchain::Tool>] list of tools to be used
     #
     # @raise [ArgumentError] If any of the tools are not supported
     #
     def self.validate_tools!(tools:)
-      unrecognized_tools = tools - Langchain::Tool::Base::TOOLS.keys
-
-      if unrecognized_tools.any?
-        raise ArgumentError, "Unrecognized Tools: #{unrecognized_tools}"
+      # Check if the tool count is equal to the unique tool count
+      if tools.count != tools.map(&:tool_name).uniq.count
+        raise ArgumentError, "Either tools are not unique or are conflicting with each other"
       end
     end
   end
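Following the new steps 1-5, a conforming custom tool might look like the sketch below; the `Weather` tool is hypothetical and not part of the gem:

```ruby
# lib/tool/weather.rb (hypothetical example)
module Langchain::Tool
  class Weather < Base
    # Step 3: NAME and DESCRIPTION constants
    NAME = "weather"

    description <<~DESC
      Useful for looking up the current weather.

      The input to this tool should be a city name.
    DESC

    # Step 4: implement `execute(input:)` and return a String
    def execute(input:)
      # A real implementation would call a weather API here
      "It is currently sunny in #{input}."
    end
  end
end

# An agent would then receive an instance:
#   tools: [Langchain::Tool::Weather.new]
```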
@@ -8,9 +8,10 @@ module Langchain::Tool
     # Gem requirements:
     #   gem "eqn", "~> 1.6.5"
     #   gem "google_search_results", "~> 2.0.0"
-    # ENV requirements: ENV["SERPAPI_API_KEY"]
     #
 
+    NAME = "calculator"
+
     description <<~DESC
       Useful for getting the result of a math expression.
 
@@ -33,8 +34,12 @@ module Langchain::Tool
     rescue Eqn::ParseError, Eqn::NoVariableValueError
       # Sometimes the input is not a pure math expression, e.g: "12F in Celsius"
       # We can use the Google answer box to evaluate this expression
-      hash_results = Langchain::Tool::SerpApi.execute_search(input: input)
-      hash_results.dig(:answer_box, :to)
+      # TODO: Figure out a better way to evaluate these language expressions.
+      hash_results = Langchain::Tool::SerpApi
+        .new(api_key: ENV["SERPAPI_API_KEY"])
+        .execute_search(input: input)
+      hash_results.dig(:answer_box, :to) ||
+        hash_results.dig(:answer_box, :result)
     end
   end
 end
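The two evaluation paths can be exercised directly; a sketch, where the second call assumes `SERPAPI_API_KEY` is set for the answer-box fallback:

```ruby
calculator = Langchain::Tool::Calculator.new

# Pure math expressions are evaluated locally by the `eqn` gem
calculator.execute(input: "2 * (3 + 4)")

# Non-math language expressions fall through to the SerpApi answer box
calculator.execute(input: "12F in Celsius")
```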
@@ -6,40 +6,55 @@ module Langchain::Tool
     # Gem requirements: gem "sequel", "~> 5.68.0"
     #
 
+    NAME = "database"
+
     description <<~DESC
       Useful for getting the result of a database query.
 
       The input to this tool should be valid SQL.
     DESC
 
+    attr_reader :db
+
+    #
     # Establish a database connection
-    # @param db_connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
-    def initialize(db_connection_string)
+    #
+    # @param connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+    # @return [Database] Database object
+    #
+    def initialize(connection_string:)
       depends_on "sequel"
       require "sequel"
       require "sequel/extensions/schema_dumper"
 
-      raise StandardError, "db_connection_string parameter cannot be blank" if db_connection_string.empty?
+      raise StandardError, "connection_string parameter cannot be blank" if connection_string.empty?
 
-      @db = Sequel.connect(db_connection_string)
+      @db = Sequel.connect(connection_string)
       @db.extension :schema_dumper
     end
 
+    #
+    # Returns the database schema
+    #
+    # @return [String] schema
+    #
     def schema
       Langchain.logger.info("[#{self.class.name}]".light_blue + ": Dumping schema")
-      @db.dump_schema_migration(same_db: true, indexes: false) unless @db.adapter_scheme == :mock
+      db.dump_schema_migration(same_db: true, indexes: false) unless db.adapter_scheme == :mock
     end
 
+    #
     # Evaluates a SQL expression
+    #
     # @param input [String] SQL expression
     # @return [Array] results
+    #
     def execute(input:)
       Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
-      begin
-        @db[input].to_a
-      rescue Sequel::DatabaseError => e
-        Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
-      end
+
+      db[input].to_a
+    rescue Sequel::DatabaseError => e
+      Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
     end
   end
 end
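With the keyword-argument initializer, the tool can be driven directly as well as through the agent; a sketch assuming a reachable database (credentials are placeholders):

```ruby
database = Langchain::Tool::Database.new(
  connection_string: "postgres://user:password@localhost:5432/db_name"
)

database.schema
# Schema dumped via Sequel's schema_dumper extension

database.execute(input: "SELECT COUNT(*) FROM users;")
# => array of result rows; Sequel::DatabaseError is logged rather than raised
```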
@@ -7,7 +7,7 @@ module Langchain::Tool
     #
     # Gem requirements: gem "safe_ruby", "~> 1.0.4"
     #
-
+    NAME = "ruby_code_interpreter"
     description <<~DESC
       A Ruby code interpreter. Use this to execute ruby expressions. Input should be a valid ruby expression. If you want to see the output of the tool, make sure to return a value.
     DESC
@@ -6,8 +6,13 @@ module Langchain::Tool
     # Wrapper around SerpAPI
     #
     # Gem requirements: gem "google_search_results", "~> 2.0.0"
-    # ENV requirements: ENV["SERPAPI_API_KEY"] # https://serpapi.com/manage-api-key
     #
+    # Usage:
+    #   search = Langchain::Tool::SerpApi.new(api_key: "YOUR_API_KEY")
+    #   search.execute(input: "What is the capital of France?")
+    #
+
+    NAME = "search"
 
     description <<~DESC
       A wrapper around Google Search.
@@ -18,39 +23,57 @@ module Langchain::Tool
       Input should be a search query.
     DESC
 
-    def initialize
+    attr_reader :api_key
+
+    #
+    # Initializes the SerpAPI tool
+    #
+    # @param api_key [String] SerpAPI API key
+    # @return [Langchain::Tool::SerpApi] SerpAPI tool
+    #
+    def initialize(api_key:)
       depends_on "google_search_results"
       require "google_search_results"
+      @api_key = api_key
     end
 
+    #
     # Executes Google Search and returns hash_results JSON
+    #
     # @param input [String] search query
     # @return [Hash] hash_results JSON
-
+    #
     def self.execute_search(input:)
       new.execute_search(input: input)
     end
 
-    # Executes Google Search and returns hash_results JSON
+    #
+    # Executes Google Search and returns the result
+    #
     # @param input [String] search query
     # @return [String] Answer
-    # TODO: Glance at all of the fields that langchain Python looks through: https://github.com/hwchase17/langchain/blob/v0.0.166/langchain/utilities/serpapi.py#L128-L156
-    # We may need to do the same thing here.
+    #
    def execute(input:)
       Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
 
       hash_results = execute_search(input: input)
 
+      # TODO: Glance at all of the fields that langchain Python looks through: https://github.com/hwchase17/langchain/blob/v0.0.166/langchain/utilities/serpapi.py#L128-L156
+      # We may need to do the same thing here.
       hash_results.dig(:answer_box, :answer) ||
         hash_results.dig(:answer_box, :snippet) ||
         hash_results.dig(:organic_results, 0, :snippet)
     end
 
+    #
+    # Executes Google Search and returns hash_results JSON
+    #
+    # @param input [String] search query
+    # @return [Hash] hash_results JSON
+    #
     def execute_search(input:)
-      GoogleSearch.new(
-        q: input,
-        serp_api_key: ENV["SERPAPI_API_KEY"]
-      )
+      GoogleSearch
+        .new(q: input, serp_api_key: api_key)
        .get_hash
     end
   end
@@ -7,7 +7,7 @@ module Langchain::Tool
     #
     # Gem requirements: gem "wikipedia-client", "~> 1.17.0"
     #
-
+    NAME = "wikipedia"
     description <<~DESC
       A wrapper around Wikipedia.
 
@@ -34,23 +34,50 @@ module Langchain
       "ada" => 2049
     }.freeze
 
+    # GOOGLE_PALM_TOKEN_LIMITS = {
+    #   "chat-bison-001" => {
+    #     "inputTokenLimit" => 4096,
+    #     "outputTokenLimit" => 1024
+    #   },
+    #   "text-bison-001" => {
+    #     "inputTokenLimit" => 8196,
+    #     "outputTokenLimit" => 1024
+    #   },
+    #   "embedding-gecko-001" => {
+    #     "inputTokenLimit" => 1024
+    #   }
+    # }.freeze
+
     #
-    # Validate the length of the text passed in to OpenAI's API
+    # Calculate the `max_tokens:` parameter by subtracting the text's token length from the model's context length
     #
     # @param text [String] The text to validate
     # @param model_name [String] The model name to validate against
-    # @return [Boolean] Whether the text is valid or not
+    # @return [Integer] The `max_tokens:` value to use
     # @raise [TokenLimitExceeded] If the text is too long
     #
-    def self.validate!(text, model_name)
-      encoder = Tiktoken.encoding_for_model(model_name)
-      token_length = encoder.encode(text).length
+    def self.validate_max_tokens!(text, model_name)
+      text_token_length = token_length(text, model_name)
+      max_tokens = TOKEN_LIMITS[model_name] - text_token_length
 
-      if token_length > TOKEN_LIMITS[model_name]
-        raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{token_length} tokens long."
+      # Raise an error even if the whole prompt exactly equals the model's token limit (max_tokens == 0), since no response could be returned
+      if max_tokens <= 0
+        raise TokenLimitExceeded, "This model's maximum context length is #{TOKEN_LIMITS[model_name]} tokens, but the given text is #{text_token_length} tokens long."
       end
 
-      true
+      max_tokens
+    end
+
+    #
+    # Calculate token length for a given text and model name
+    #
+    # @param text [String] The text to validate
+    # @param model_name [String] The model name to validate against
+    # @return [Integer] The token length of the text
+    #
+    def self.token_length(text, model_name)
+      encoder = Tiktoken.encoding_for_model(model_name)
+      encoder.encode(text).length
     end
   end
 end
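The change is simple context-window accounting. A worked sketch with illustrative numbers, using the "ada" limit of 2049 from the TOKEN_LIMITS table above:

```ruby
context_length = 2049   # TOKEN_LIMITS["ada"]
prompt_tokens  = 1800   # illustrative Tiktoken-encoded length of the prompt

max_tokens = context_length - prompt_tokens
# => 249 tokens left for the completion; this is the value
# validate_max_tokens! returns and what gets sent as max_tokens:

# If prompt_tokens >= 2049, max_tokens <= 0 and TokenLimitExceeded is raised
```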
@@ -1,5 +1,5 @@
 # frozen_string_literal: true
 
 module Langchain
-  VERSION = "0.5.0"
+  VERSION = "0.5.2"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: langchainrb
 version: !ruby/object:Gem::Version
-  version: 0.5.0
+  version: 0.5.2
 platform: ruby
 authors:
 - Andrei Bondarev
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2023-06-05 00:00:00.000000000 Z
+date: 2023-06-08 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: tiktoken_ruby