langchainrb 0.4.1 → 0.5.0

Files changed (47)
  1. checksums.yaml +4 -4
  2. data/.env.example +2 -1
  3. data/.rubocop.yml +11 -0
  4. data/CHANGELOG.md +13 -0
  5. data/Gemfile +2 -0
  6. data/Gemfile.lock +14 -1
  7. data/README.md +42 -7
  8. data/Rakefile +5 -0
  9. data/examples/pdf_store_and_query_with_chroma.rb +1 -2
  10. data/examples/store_and_query_with_pinecone.rb +1 -2
  11. data/examples/store_and_query_with_qdrant.rb +1 -2
  12. data/examples/store_and_query_with_weaviate.rb +1 -2
  13. data/lefthook.yml +5 -0
  14. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +6 -10
  15. data/lib/langchain/agent/sql_query_agent/sql_query_agent.rb +78 -0
  16. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json +10 -0
  17. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json +10 -0
  18. data/lib/langchain/dependency_helper.rb +34 -0
  19. data/lib/langchain/llm/ai21.rb +45 -0
  20. data/lib/langchain/llm/base.rb +2 -19
  21. data/lib/langchain/llm/cohere.rb +9 -0
  22. data/lib/langchain/llm/google_palm.rb +7 -0
  23. data/lib/langchain/llm/hugging_face.rb +9 -0
  24. data/lib/langchain/llm/openai.rb +33 -41
  25. data/lib/langchain/llm/replicate.rb +5 -2
  26. data/lib/langchain/processors/base.rb +2 -0
  27. data/lib/langchain/processors/xlsx.rb +27 -0
  28. data/lib/langchain/prompt/base.rb +8 -4
  29. data/lib/langchain/prompt/loading.rb +6 -1
  30. data/lib/langchain/prompt/prompt_template.rb +1 -1
  31. data/lib/langchain/tool/base.rb +4 -1
  32. data/lib/langchain/tool/calculator.rb +9 -0
  33. data/lib/langchain/tool/database.rb +45 -0
  34. data/lib/langchain/tool/ruby_code_interpreter.rb +6 -0
  35. data/lib/langchain/tool/serp_api.rb +5 -1
  36. data/lib/langchain/tool/wikipedia.rb +4 -0
  37. data/lib/langchain/vectorsearch/base.rb +8 -14
  38. data/lib/langchain/vectorsearch/chroma.rb +15 -7
  39. data/lib/langchain/vectorsearch/milvus.rb +13 -4
  40. data/lib/langchain/vectorsearch/pgvector.rb +15 -8
  41. data/lib/langchain/vectorsearch/pinecone.rb +15 -7
  42. data/lib/langchain/vectorsearch/qdrant.rb +15 -7
  43. data/lib/langchain/vectorsearch/weaviate.rb +15 -7
  44. data/lib/{version.rb → langchain/version.rb} +1 -1
  45. data/lib/langchain.rb +6 -2
  46. metadata +82 -4
  47. data/lib/dependency_helper.rb +0 -30
data/lib/langchain/llm/openai.rb
@@ -2,6 +2,15 @@
 
 module Langchain::LLM
   class OpenAI < Base
+    #
+    # Wrapper around OpenAI APIs.
+    #
+    # Gem requirements: gem "ruby-openai", "~> 4.0.0"
+    #
+    # Usage:
+    #     openai = Langchain::LLM::OpenAI.new(api_key:, llm_options: {})
+    #
+
     DEFAULTS = {
       temperature: 0.0,
       completion_model_name: "text-davinci-003",
@@ -10,12 +19,11 @@ module Langchain::LLM
       dimension: 1536
     }.freeze
 
-    def initialize(api_key:)
+    def initialize(api_key:, llm_options: {})
       depends_on "ruby-openai"
       require "openai"
 
-      # TODO: Add support to pass `organization_id:`
-      @client = ::OpenAI::Client.new(access_token: api_key)
+      @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
     end
 
     #
@@ -24,17 +32,12 @@ module Langchain::LLM
     # @param text [String] The text to generate an embedding for
     # @return [Array] The embedding
     #
-    def embed(text:)
-      model = DEFAULTS[:embeddings_model_name]
+    def embed(text:, **params)
+      parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
 
-      Langchain::Utils::TokenLengthValidator.validate!(text, model)
+      Langchain::Utils::TokenLengthValidator.validate!(text, parameters[:model])
 
-      response = client.embeddings(
-        parameters: {
-          model: model,
-          input: text
-        }
-      )
+      response = client.embeddings(parameters: parameters.merge(params))
       response.dig("data").first.dig("embedding")
     end
 
@@ -45,23 +48,13 @@ module Langchain::LLM
     # @return [String] The completion
     #
     def complete(prompt:, **params)
-      model = DEFAULTS[:completion_model_name]
-
-      Langchain::Utils::TokenLengthValidator.validate!(prompt, model)
-
-      default_params = {
-        model: model,
-        temperature: DEFAULTS[:temperature],
-        prompt: prompt
-      }
+      parameters = compose_parameters DEFAULTS[:completion_model_name], params
 
-      if params[:stop_sequences]
-        default_params[:stop] = params.delete(:stop_sequences)
-      end
+      Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
 
-      default_params.merge!(params)
+      parameters[:prompt] = prompt
 
-      response = client.completions(parameters: default_params)
+      response = client.completions(parameters: parameters)
       response.dig("choices", 0, "text")
     end
 
@@ -72,24 +65,13 @@ module Langchain::LLM
     # @return [String] The chat completion
     #
     def chat(prompt:, **params)
-      model = DEFAULTS[:chat_completion_model_name]
+      parameters = compose_parameters DEFAULTS[:chat_completion_model_name], params
 
-      Langchain::Utils::TokenLengthValidator.validate!(prompt, model)
+      Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
 
-      default_params = {
-        model: model,
-        temperature: DEFAULTS[:temperature],
-        # TODO: Figure out how to introduce persisted conversations
-        messages: [{role: "user", content: prompt}]
-      }
-
-      if params[:stop_sequences]
-        default_params[:stop] = params.delete(:stop_sequences)
-      end
+      parameters[:messages] = [{role: "user", content: prompt}]
 
-      default_params.merge!(params)
-
-      response = client.chat(parameters: default_params)
+      response = client.chat(parameters: parameters)
       response.dig("choices", 0, "message", "content")
     end
 
@@ -112,5 +94,15 @@ module Langchain::LLM
         max_tokens: 2048
       )
    end
+
+    private
+
+    def compose_parameters(model, params)
+      default_params = {model: model, temperature: DEFAULTS[:temperature]}
+
+      default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
+
+      default_params.merge(params)
+    end
   end
 end
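The practical upshot of this refactor: `llm_options` is splatted straight into `::OpenAI::Client.new`, so any option the ruby-openai client accepts can be passed through, and `compose_parameters` still maps `stop_sequences:` to the API's `stop` parameter. A minimal sketch (key values hypothetical; `organization_id` is assumed to be a keyword ruby-openai 4.x supports):

    openai = Langchain::LLM::OpenAI.new(
      api_key: ENV["OPENAI_API_KEY"],
      llm_options: {organization_id: "org-..."} # forwarded verbatim to ::OpenAI::Client.new
    )
    openai.complete(prompt: "Name three prime numbers.", stop_sequences: ["\n"])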
data/lib/langchain/llm/replicate.rb
@@ -2,7 +2,11 @@
 
 module Langchain::LLM
   class Replicate < Base
+    #
     # Wrapper around Replicate.com LLM provider
+    #
+    # Gem requirements: gem "replicate-ruby", "~> 0.2.2"
+    #
     # Use it directly:
     # replicate = LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
     #
@@ -10,8 +14,7 @@ module Langchain::LLM
     # chroma = Vectorsearch::Chroma.new(
     #   url: ENV["CHROMA_URL"],
     #   index_name: "...",
-    #   llm: :replicate,
-    #   llm_api_key: ENV["REPLICATE_API_KEY"],
+    #   llm: Langchain::LLM::Replicate(api_key: ENV["REPLICATE_API_KEY"])
     # )
 
     DEFAULTS = {
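Note that the rewritten usage comment above omits `.new`; instantiating the class actually looks like this (a sketch, reading the key from the environment):

    replicate = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])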
data/lib/langchain/processors/base.rb
@@ -3,6 +3,8 @@
 module Langchain
   module Processors
     class Base
+      include Langchain::DependencyHelper
+
       EXTENSIONS = []
       CONTENT_TYPES = []
 
data/lib/langchain/processors/xlsx.rb
@@ -0,0 +1,27 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Processors
+    class Xlsx < Base
+      EXTENSIONS = [".xlsx", ".xlsm"].freeze
+      CONTENT_TYPES = ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"].freeze
+
+      def initialize(*)
+        depends_on "roo"
+        require "roo"
+      end
+
+      # Parse the document and return the text
+      # @param [File] data
+      # @return [Array<Array<String>>] Array of rows, each row is an array of cells
+      def parse(data)
+        xlsx_file = Roo::Spreadsheet.open(data)
+        xlsx_file.each_with_pagename.flat_map do |_, sheet|
+          sheet.map do |row|
+            row.map { |i| i.to_s.strip }
+          end
+        end
+      end
+    end
+  end
+end
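A quick sketch of the new processor in isolation (file name hypothetical; `parse` yields one array of stripped cell strings per row, across all sheets):

    processor = Langchain::Processors::Xlsx.new
    processor.parse(File.open("inventory.xlsx"))
    # => [["sku", "qty"], ["A-1", "12"], ...]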
data/lib/langchain/prompt/base.rb
@@ -2,6 +2,7 @@
 
 require "strscan"
 require "json"
+require "yaml"
 
 module Langchain::Prompt
   class Base
@@ -52,10 +53,13 @@ module Langchain::Prompt
       directory_path = save_path.dirname
       FileUtils.mkdir_p(directory_path) unless directory_path.directory?
 
-      if save_path.extname == ".json"
+      case save_path.extname
+      when ".json"
         File.write(file_path, to_h.to_json)
+      when ".yaml", ".yml"
+        File.write(file_path, to_h.to_yaml)
       else
-        raise ArgumentError, "#{file_path} must be json"
+        raise ArgumentError, "#{file_path} must be json or yaml file"
       end
     end
 
@@ -64,9 +68,9 @@ module Langchain::Prompt
     #
     # This method takes a template string and returns an array of input variable names
     # contained within the template. Input variables are defined as text enclosed in
-    # curly braces (e.g. "{variable_name}").
+    # curly braces (e.g. <code>\{variable_name\}</code>).
     #
-    # Content within two consecutive curly braces (e.g. "{{ignore_me}}) are ignored.
+    # Content within two consecutive curly braces (e.g. <code>\{\{ignore_me}}</code>) are ignored.
     #
     # @param template [String] The template string to extract variables from.
     #
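With the `case` added to `#save`, prompts now serialize to YAML as well as JSON, keyed off the file extension. A sketch (file name hypothetical):

    prompt = Langchain::Prompt::PromptTemplate.new(
      template: "Tell me a {adjective} joke.",
      input_variables: ["adjective"]
    )
    prompt.save(file_path: "prompt_template.yaml") # writes to_h as YAML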
data/lib/langchain/prompt/loading.rb
@@ -2,6 +2,8 @@
 
 require "strscan"
 require "pathname"
+require "json"
+require "yaml"
 
 module Langchain::Prompt
   TYPE_TO_LOADER = {
@@ -22,8 +24,11 @@ module Langchain::Prompt
     def load_from_path(file_path:)
       file_path = file_path.is_a?(String) ? Pathname.new(file_path) : file_path
 
-      if file_path.extname == ".json"
+      case file_path.extname
+      when ".json"
         config = JSON.parse(File.read(file_path))
+      when ".yaml", ".yml"
+        config = YAML.safe_load(File.read(file_path))
       else
         raise ArgumentError, "Got unsupported file type #{file_path.extname}"
       end
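Loading dispatches on the same extensions, so a YAML prompt round-trips (continuing the hypothetical file above):

    prompt = Langchain::Prompt.load_from_path(file_path: "prompt_template.yaml")
    prompt.format(adjective: "funny")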
data/lib/langchain/prompt/prompt_template.rb
@@ -20,7 +20,7 @@ module Langchain::Prompt
     end
 
     #
-    # Format the prompt with the inputs. Double {{}} replaced with single {} to adhere to f-string spec.
+    # Format the prompt with the inputs. Double <code>{{}}</code> replaced with single <code>{}</code> to adhere to f-string spec.
     #
     # @param kwargs [Hash] Any arguments to be passed to the prompt template.
     # @return [String] A formatted string.
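To illustrate the escaping rule the reworded comment describes, a sketch:

    prompt = Langchain::Prompt::PromptTemplate.new(
      template: "Reply with JSON like {{\"topic\": ...}} about {subject}.",
      input_variables: ["subject"]
    )
    prompt.format(subject: "Ruby")
    # => 'Reply with JSON like {"topic": ...} about Ruby.'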
data/lib/langchain/tool/base.rb
@@ -2,6 +2,8 @@
 
 module Langchain::Tool
   class Base
+    include Langchain::DependencyHelper
+
     # How to add additional Tools?
     # 1. Create a new file in lib/tool/your_tool_name.rb
     # 2. Add your tool to the TOOLS hash below
@@ -12,7 +14,8 @@ module Langchain::Tool
     TOOLS = {
       "calculator" => "Langchain::Tool::Calculator",
       "search" => "Langchain::Tool::SerpApi",
-      "wikipedia" => "Langchain::Tool::Wikipedia"
+      "wikipedia" => "Langchain::Tool::Wikipedia",
+      "database" => "Langchain::Tool::Database"
     }
 
     def self.description(value)
data/lib/langchain/tool/calculator.rb
@@ -2,6 +2,15 @@
 
 module Langchain::Tool
   class Calculator < Base
+    #
+    # A calculator tool that falls back to the Google calculator widget
+    #
+    # Gem requirements:
+    #     gem "eqn", "~> 1.6.5"
+    #     gem "google_search_results", "~> 2.0.0"
+    # ENV requirements: ENV["SERPAPI_API_KEY"]
+    #
+
     description <<~DESC
       Useful for getting the result of a math expression.
 
data/lib/langchain/tool/database.rb
@@ -0,0 +1,45 @@
+module Langchain::Tool
+  class Database < Base
+    #
+    # Connects to a database, executes SQL queries, and outputs DB schema for Agents to use
+    #
+    # Gem requirements: gem "sequel", "~> 5.68.0"
+    #
+
+    description <<~DESC
+      Useful for getting the result of a database query.
+
+      The input to this tool should be valid SQL.
+    DESC
+
+    # Establish a database connection
+    # @param db_connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
+    def initialize(db_connection_string)
+      depends_on "sequel"
+      require "sequel"
+      require "sequel/extensions/schema_dumper"
+
+      raise StandardError, "db_connection_string parameter cannot be blank" if db_connection_string.empty?
+
+      @db = Sequel.connect(db_connection_string)
+      @db.extension :schema_dumper
+    end
+
+    def schema
+      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Dumping schema")
+      @db.dump_schema_migration(same_db: true, indexes: false) unless @db.adapter_scheme == :mock
+    end
+
+    # Evaluates a sql expression
+    # @param input [String] sql expression
+    # @return [Array] results
+    def execute(input:)
+      Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
+      begin
+        @db[input].to_a
+      rescue Sequel::DatabaseError => e
+        Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
+      end
+    end
+  end
+end
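Exercising the new tool directly (connection string and result values hypothetical):

    db = Langchain::Tool::Database.new("postgres://user:password@localhost:5432/app_db")
    db.schema                                        # Sequel schema dump for the agent's context
    db.execute(input: "SELECT COUNT(*) FROM users;") # => rows as hashes, or logs a Sequel::DatabaseError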
data/lib/langchain/tool/ruby_code_interpreter.rb
@@ -2,6 +2,12 @@
 
 module Langchain::Tool
   class RubyCodeInterpreter < Base
+    #
+    # A tool that execute Ruby code in a sandboxed environment.
+    #
+    # Gem requirements: gem "safe_ruby", "~> 1.0.4"
+    #
+
     description <<~DESC
       A Ruby code interpreter. Use this to execute ruby expressions. Input should be a valid ruby expression. If you want to see the output of the tool, make sure to return a value.
     DESC
data/lib/langchain/tool/serp_api.rb
@@ -2,8 +2,12 @@
 
 module Langchain::Tool
   class SerpApi < Base
+    #
     # Wrapper around SerpAPI
-    # Set ENV["SERPAPI_API_KEY"] to use it
+    #
+    # Gem requirements: gem "google_search_results", "~> 2.0.0"
+    # ENV requirements: ENV["SERPAPI_API_KEY"] # https://serpapi.com/manage-api-key)
+    #
 
     description <<~DESC
       A wrapper around Google Search.
data/lib/langchain/tool/wikipedia.rb
@@ -2,7 +2,11 @@
 
 module Langchain::Tool
   class Wikipedia < Base
+    #
     # Tool that adds the capability to search using the Wikipedia API
+    #
+    # Gem requirements: gem "wikipedia-client", "~> 1.17.0"
+    #
 
     description <<~DESC
       A wrapper around Wikipedia.
data/lib/langchain/vectorsearch/base.rb
@@ -4,21 +4,16 @@ require "forwardable"
 
 module Langchain::Vectorsearch
   class Base
+    include Langchain::DependencyHelper
     extend Forwardable
 
-    attr_reader :client, :index_name, :llm, :llm_api_key, :llm_client
+    attr_reader :client, :index_name, :llm
 
     DEFAULT_METRIC = "cosine"
 
-    # @param llm [Symbol] The LLM to use
-    # @param llm_api_key [String] The API key for the LLM
-    def initialize(llm:, llm_api_key:)
-      Langchain::LLM::Base.validate_llm!(llm: llm)
-
+    # @param llm [Object] The LLM client to use
+    def initialize(llm:)
       @llm = llm
-      @llm_api_key = llm_api_key
-
-      @llm_client = Langchain::LLM.const_get(Langchain::LLM::Base::LLMS.fetch(llm)).new(api_key: llm_api_key)
     end
 
     # Method supported by Vectorsearch DB to create a default schema
@@ -47,7 +42,7 @@ module Langchain::Vectorsearch
       raise NotImplementedError, "#{self.class.name} does not support asking questions"
     end
 
-    def_delegators :llm_client,
+    def_delegators :llm,
       :default_dimension
 
     def generate_prompt(question:, context:)
@@ -68,11 +63,10 @@ module Langchain::Vectorsearch
       prompt_template.format(question: question)
     end
 
-    def add_data(path: nil, paths: nil)
-      raise ArgumentError, "Either path or paths must be provided" if path.nil? && paths.nil?
-      raise ArgumentError, "Either path or paths must be provided, not both" if !path.nil? && !paths.nil?
+    def add_data(paths:)
+      raise ArgumentError, "Paths must be provided" if paths.to_a.empty?
 
-      texts = Array(path || paths)
+      texts = Array(paths)
         .flatten
         .map { |path| Langchain::Loader.new(path)&.load&.value }
        .compact
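This is the release's main breaking change: every vectorsearch store now receives a constructed LLM client instead of a symbol plus `llm_api_key:`, and `add_data` takes only `paths:`. A sketch with hypothetical credentials and index name:

    weaviate = Langchain::Vectorsearch::Weaviate.new(
      url: ENV["WEAVIATE_URL"],
      api_key: ENV["WEAVIATE_API_KEY"],
      index_name: "Recipes",
      llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
    )
    weaviate.add_data(paths: ["docs/guide.pdf"]) # `path:` is gone; pass a list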
data/lib/langchain/vectorsearch/chroma.rb
@@ -2,13 +2,21 @@
 
 module Langchain::Vectorsearch
   class Chroma < Base
+    #
+    # Wrapper around Chroma DB
+    #
+    # Gem requirements: gem "chroma-db", "~> 0.3.0"
+    #
+    # Usage:
+    #     chroma = Langchain::Vectorsearch::Chroma.new(url:, index_name:, llm:, llm_api_key:, api_key: nil)
+    #
+
     # Initialize the Chroma client
     # @param url [String] The URL of the Qdrant server
     # @param api_key [String] The API key to use
     # @param index_name [String] The name of the index to use
-    # @param llm [Symbol] The LLM to use
-    # @param llm_api_key [String] The API key for the LLM
-    def initialize(url:, index_name:, llm:, llm_api_key:, api_key: nil)
+    # @param llm [Object] The LLM client to use
+    def initialize(url:, index_name:, llm:, api_key: nil)
       depends_on "chroma-db"
       require "chroma-db"
 
@@ -18,7 +26,7 @@ module Langchain::Vectorsearch
 
       @index_name = index_name
 
-      super(llm: llm, llm_api_key: llm_api_key)
+      super(llm: llm)
     end
 
     # Add a list of texts to the index
@@ -29,7 +37,7 @@ module Langchain::Vectorsearch
       ::Chroma::Resources::Embedding.new(
         # TODO: Add support for passing your own IDs
         id: SecureRandom.uuid,
-        embedding: llm_client.embed(text: text),
+        embedding: llm.embed(text: text),
         # TODO: Add support for passing metadata
         metadata: [], # metadatas[index],
         document: text # Do we actually need to store the whole original document?
@@ -54,7 +62,7 @@ module Langchain::Vectorsearch
       query:,
       k: 4
     )
-      embedding = llm_client.embed(text: query)
+      embedding = llm.embed(text: query)
 
       similarity_search_by_vector(
         embedding: embedding,
@@ -92,7 +100,7 @@ module Langchain::Vectorsearch
 
       prompt = generate_prompt(question: question, context: context)
 
-      llm_client.chat(prompt: prompt)
+      llm.chat(prompt: prompt)
     end
 
     private
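One caveat: the new Usage comment above still lists `llm_api_key:`, but the initializer in this same diff drops it. The working form, per the committed signature (index name hypothetical):

    chroma = Langchain::Vectorsearch::Chroma.new(
      url: ENV["CHROMA_URL"],
      index_name: "documents",
      llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
    )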
data/lib/langchain/vectorsearch/milvus.rb
@@ -2,14 +2,23 @@
 
 module Langchain::Vectorsearch
   class Milvus < Base
-    def initialize(url:, index_name:, llm:, llm_api_key:, api_key: nil)
+    #
+    # Wrapper around Milvus REST APIs.
+    #
+    # Gem requirements: gem "milvus", "~> 0.9.0"
+    #
+    # Usage:
+    #     milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:, llm_api_key:)
+    #
+
+    def initialize(url:, index_name:, llm:, api_key: nil)
       depends_on "milvus"
       require "milvus"
 
       @client = ::Milvus::Client.new(url: url)
       @index_name = index_name
 
-      super(llm: llm, llm_api_key: llm_api_key)
+      super(llm: llm)
     end
 
     def add_texts(texts:)
@@ -24,7 +33,7 @@ module Langchain::Vectorsearch
         }, {
           field_name: "vectors",
           type: ::Milvus::DATA_TYPES["binary_vector"],
-          field: Array(texts).map { |text| llm_client.embed(text: text) }
+          field: Array(texts).map { |text| llm.embed(text: text) }
         }
       ]
     )
@@ -69,7 +78,7 @@ module Langchain::Vectorsearch
     end
 
     def similarity_search(query:, k: 4)
-      embedding = llm_client.embed(text: query)
+      embedding = llm.embed(text: query)
 
       similarity_search_by_vector(
         embedding: embedding,
data/lib/langchain/vectorsearch/pgvector.rb
@@ -1,8 +1,16 @@
 # frozen_string_literal: true
 
 module Langchain::Vectorsearch
-  # The PostgreSQL vector search adapter
   class Pgvector < Base
+    #
+    # The PostgreSQL vector search adapter
+    #
+    # Gem requirements: gem "pgvector", "~> 0.2"
+    #
+    # Usage:
+    #     pgvector = Langchain::Vectorsearch::Pgvector.new(url:, index_name:, llm:, llm_api_key:)
+    #
+
     # The operators supported by the PostgreSQL vector search adapter
     OPERATORS = {
       "cosine_distance" => "<=>",
@@ -14,10 +22,9 @@ module Langchain::Vectorsearch
 
     # @param url [String] The URL of the PostgreSQL database
     # @param index_name [String] The name of the table to use for the index
-    # @param llm [String] The URL of the Language Layer API
-    # @param llm_api_key [String] The API key for the Language Layer API
+    # @param llm [Object] The LLM client to use
     # @param api_key [String] The API key for the Vectorsearch DB (not used for PostgreSQL)
-    def initialize(url:, index_name:, llm:, llm_api_key:, api_key: nil)
+    def initialize(url:, index_name:, llm:, api_key: nil)
       require "pg"
       require "pgvector"
 
@@ -30,7 +37,7 @@ module Langchain::Vectorsearch
       @quoted_table_name = @client.quote_ident(index_name)
       @operator = OPERATORS[DEFAULT_OPERATOR]
 
-      super(llm: llm, llm_api_key: llm_api_key)
+      super(llm: llm)
     end
 
     # Add a list of texts to the index
@@ -38,7 +45,7 @@ module Langchain::Vectorsearch
     # @return [PG::Result] The response from the database
     def add_texts(texts:)
       data = texts.flat_map do |text|
-        [text, llm_client.embed(text: text)]
+        [text, llm.embed(text: text)]
       end
       values = texts.length.times.map { |i| "($#{2 * i + 1}, $#{2 * i + 2})" }.join(",")
       client.exec_params(
@@ -67,7 +74,7 @@ module Langchain::Vectorsearch
     # @param k [Integer] The number of top results to return
     # @return [Array<Hash>] The results of the search
     def similarity_search(query:, k: 4)
-      embedding = llm_client.embed(text: query)
+      embedding = llm.embed(text: query)
 
       similarity_search_by_vector(
         embedding: embedding,
@@ -105,7 +112,7 @@ module Langchain::Vectorsearch
 
       prompt = generate_prompt(question: question, context: context)
 
-      llm_client.chat(prompt: prompt)
+      llm.chat(prompt: prompt)
     end
   end
 end
data/lib/langchain/vectorsearch/pinecone.rb
@@ -2,13 +2,21 @@
 
 module Langchain::Vectorsearch
   class Pinecone < Base
+    #
+    # Wrapper around Pinecone API.
+    #
+    # Gem requirements: gem "pinecone", "~> 0.1.6"
+    #
+    # Usage:
+    #     pinecone = Langchain::Vectorsearch::Pinecone.new(environment:, api_key:, index_name:, llm:, llm_api_key:)
+    #
+
     # Initialize the Pinecone client
     # @param environment [String] The environment to use
     # @param api_key [String] The API key to use
     # @param index_name [String] The name of the index to use
-    # @param llm [Symbol] The LLM to use
-    # @param llm_api_key [String] The API key for the LLM
-    def initialize(environment:, api_key:, index_name:, llm:, llm_api_key:)
+    # @param llm [Object] The LLM client to use
+    def initialize(environment:, api_key:, index_name:, llm:)
       depends_on "pinecone"
       require "pinecone"
 
@@ -20,7 +28,7 @@ module Langchain::Vectorsearch
       @client = ::Pinecone::Client.new
       @index_name = index_name
 
-      super(llm: llm, llm_api_key: llm_api_key)
+      super(llm: llm)
     end
 
     # Add a list of texts to the index
@@ -34,7 +42,7 @@ module Langchain::Vectorsearch
         # TODO: Allows passing in your own IDs
         id: SecureRandom.uuid,
         metadata: metadata || {content: text},
-        values: llm_client.embed(text: text)
+        values: llm.embed(text: text)
       }
     end
 
@@ -65,7 +73,7 @@ module Langchain::Vectorsearch
       namespace: "",
       filter: nil
     )
-      embedding = llm_client.embed(text: query)
+      embedding = llm.embed(text: query)
 
       similarity_search_by_vector(
         embedding: embedding,
@@ -112,7 +120,7 @@ module Langchain::Vectorsearch
 
       prompt = generate_prompt(question: question, context: context)
 
-      llm_client.chat(prompt: prompt)
+      llm.chat(prompt: prompt)
     end
   end
 end
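End to end, the renamed delegate means each store calls `llm.embed` for retrieval and `llm.chat` for answering. A closing sketch with hypothetical values:

    pinecone = Langchain::Vectorsearch::Pinecone.new(
      environment: ENV["PINECONE_ENVIRONMENT"],
      api_key: ENV["PINECONE_API_KEY"],
      index_name: "articles",
      llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
    )
    pinecone.similarity_search(query: "ruby gems", k: 4) # embeds the query via llm.embed
    pinecone.ask(question: "What is a gem?")             # retrieves context, then llm.chat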