langchainrb 0.4.1 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/.env.example +2 -1
  3. data/.rubocop.yml +11 -0
  4. data/CHANGELOG.md +13 -0
  5. data/Gemfile +2 -0
  6. data/Gemfile.lock +14 -1
  7. data/README.md +42 -7
  8. data/Rakefile +5 -0
  9. data/examples/pdf_store_and_query_with_chroma.rb +1 -2
  10. data/examples/store_and_query_with_pinecone.rb +1 -2
  11. data/examples/store_and_query_with_qdrant.rb +1 -2
  12. data/examples/store_and_query_with_weaviate.rb +1 -2
  13. data/lefthook.yml +5 -0
  14. data/lib/langchain/agent/chain_of_thought_agent/chain_of_thought_agent.rb +6 -10
  15. data/lib/langchain/agent/sql_query_agent/sql_query_agent.rb +78 -0
  16. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.json +10 -0
  17. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.json +10 -0
  18. data/lib/langchain/dependency_helper.rb +34 -0
  19. data/lib/langchain/llm/ai21.rb +45 -0
  20. data/lib/langchain/llm/base.rb +2 -19
  21. data/lib/langchain/llm/cohere.rb +9 -0
  22. data/lib/langchain/llm/google_palm.rb +7 -0
  23. data/lib/langchain/llm/hugging_face.rb +9 -0
  24. data/lib/langchain/llm/openai.rb +33 -41
  25. data/lib/langchain/llm/replicate.rb +5 -2
  26. data/lib/langchain/processors/base.rb +2 -0
  27. data/lib/langchain/processors/xlsx.rb +27 -0
  28. data/lib/langchain/prompt/base.rb +8 -4
  29. data/lib/langchain/prompt/loading.rb +6 -1
  30. data/lib/langchain/prompt/prompt_template.rb +1 -1
  31. data/lib/langchain/tool/base.rb +4 -1
  32. data/lib/langchain/tool/calculator.rb +9 -0
  33. data/lib/langchain/tool/database.rb +45 -0
  34. data/lib/langchain/tool/ruby_code_interpreter.rb +6 -0
  35. data/lib/langchain/tool/serp_api.rb +5 -1
  36. data/lib/langchain/tool/wikipedia.rb +4 -0
  37. data/lib/langchain/vectorsearch/base.rb +8 -14
  38. data/lib/langchain/vectorsearch/chroma.rb +15 -7
  39. data/lib/langchain/vectorsearch/milvus.rb +13 -4
  40. data/lib/langchain/vectorsearch/pgvector.rb +15 -8
  41. data/lib/langchain/vectorsearch/pinecone.rb +15 -7
  42. data/lib/langchain/vectorsearch/qdrant.rb +15 -7
  43. data/lib/langchain/vectorsearch/weaviate.rb +15 -7
  44. data/lib/{version.rb → langchain/version.rb} +1 -1
  45. data/lib/langchain.rb +6 -2
  46. metadata +82 -4
  47. data/lib/dependency_helper.rb +0 -30
@@ -2,6 +2,15 @@
2
2
 
3
3
  module Langchain::LLM
4
4
  class OpenAI < Base
5
+ #
6
+ # Wrapper around OpenAI APIs.
7
+ #
8
+ # Gem requirements: gem "ruby-openai", "~> 4.0.0"
9
+ #
10
+ # Usage:
11
+ # openai = Langchain::LLM::OpenAI.new(api_key:, llm_options: {})
12
+ #
13
+
5
14
  DEFAULTS = {
6
15
  temperature: 0.0,
7
16
  completion_model_name: "text-davinci-003",
@@ -10,12 +19,11 @@ module Langchain::LLM
10
19
  dimension: 1536
11
20
  }.freeze
12
21
 
13
- def initialize(api_key:)
22
+ def initialize(api_key:, llm_options: {})
14
23
  depends_on "ruby-openai"
15
24
  require "openai"
16
25
 
17
- # TODO: Add support to pass `organization_id:`
18
- @client = ::OpenAI::Client.new(access_token: api_key)
26
+ @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
19
27
  end
20
28
 
21
29
  #
@@ -24,17 +32,12 @@ module Langchain::LLM
24
32
  # @param text [String] The text to generate an embedding for
25
33
  # @return [Array] The embedding
26
34
  #
27
- def embed(text:)
28
- model = DEFAULTS[:embeddings_model_name]
35
+ def embed(text:, **params)
36
+ parameters = {model: DEFAULTS[:embeddings_model_name], input: text}
29
37
 
30
- Langchain::Utils::TokenLengthValidator.validate!(text, model)
38
+ Langchain::Utils::TokenLengthValidator.validate!(text, parameters[:model])
31
39
 
32
- response = client.embeddings(
33
- parameters: {
34
- model: model,
35
- input: text
36
- }
37
- )
40
+ response = client.embeddings(parameters: parameters.merge(params))
38
41
  response.dig("data").first.dig("embedding")
39
42
  end
40
43
 
@@ -45,23 +48,13 @@ module Langchain::LLM
45
48
  # @return [String] The completion
46
49
  #
47
50
  def complete(prompt:, **params)
48
- model = DEFAULTS[:completion_model_name]
49
-
50
- Langchain::Utils::TokenLengthValidator.validate!(prompt, model)
51
-
52
- default_params = {
53
- model: model,
54
- temperature: DEFAULTS[:temperature],
55
- prompt: prompt
56
- }
51
+ parameters = compose_parameters DEFAULTS[:completion_model_name], params
57
52
 
58
- if params[:stop_sequences]
59
- default_params[:stop] = params.delete(:stop_sequences)
60
- end
53
+ Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
61
54
 
62
- default_params.merge!(params)
55
+ parameters[:prompt] = prompt
63
56
 
64
- response = client.completions(parameters: default_params)
57
+ response = client.completions(parameters: parameters)
65
58
  response.dig("choices", 0, "text")
66
59
  end
67
60
 
@@ -72,24 +65,13 @@ module Langchain::LLM
72
65
  # @return [String] The chat completion
73
66
  #
74
67
  def chat(prompt:, **params)
75
- model = DEFAULTS[:chat_completion_model_name]
68
+ parameters = compose_parameters DEFAULTS[:chat_completion_model_name], params
76
69
 
77
- Langchain::Utils::TokenLengthValidator.validate!(prompt, model)
70
+ Langchain::Utils::TokenLengthValidator.validate!(prompt, parameters[:model])
78
71
 
79
- default_params = {
80
- model: model,
81
- temperature: DEFAULTS[:temperature],
82
- # TODO: Figure out how to introduce persisted conversations
83
- messages: [{role: "user", content: prompt}]
84
- }
85
-
86
- if params[:stop_sequences]
87
- default_params[:stop] = params.delete(:stop_sequences)
88
- end
72
+ parameters[:messages] = [{role: "user", content: prompt}]
89
73
 
90
- default_params.merge!(params)
91
-
92
- response = client.chat(parameters: default_params)
74
+ response = client.chat(parameters: parameters)
93
75
  response.dig("choices", 0, "message", "content")
94
76
  end
95
77
 
@@ -112,5 +94,15 @@ module Langchain::LLM
112
94
  max_tokens: 2048
113
95
  )
114
96
  end
97
+
98
+ private
99
+
100
+ def compose_parameters(model, params)
101
+ default_params = {model: model, temperature: DEFAULTS[:temperature]}
102
+
103
+ default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
104
+
105
+ default_params.merge(params)
106
+ end
115
107
  end
116
108
  end
@@ -2,7 +2,11 @@
2
2
 
3
3
  module Langchain::LLM
4
4
  class Replicate < Base
5
+ #
5
6
  # Wrapper around Replicate.com LLM provider
7
+ #
8
+ # Gem requirements: gem "replicate-ruby", "~> 0.2.2"
9
+ #
6
10
  # Use it directly:
7
11
  # replicate = LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
8
12
  #
@@ -10,8 +14,7 @@ module Langchain::LLM
10
14
  # chroma = Vectorsearch::Chroma.new(
11
15
  # url: ENV["CHROMA_URL"],
12
16
  # index_name: "...",
13
- # llm: :replicate,
14
- # llm_api_key: ENV["REPLICATE_API_KEY"],
17
+ # llm: Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])
15
18
  # )
16
19
 
17
20
  DEFAULTS = {
@@ -3,6 +3,8 @@
3
3
  module Langchain
4
4
  module Processors
5
5
  class Base
6
+ include Langchain::DependencyHelper
7
+
6
8
  EXTENSIONS = []
7
9
  CONTENT_TYPES = []
8
10
 
@@ -0,0 +1,27 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Langchain
4
+ module Processors
5
+ class Xlsx < Base
6
+ EXTENSIONS = [".xlsx", ".xlsm"].freeze
7
+ CONTENT_TYPES = ["application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"].freeze
8
+
9
+ def initialize(*)
10
+ depends_on "roo"
11
+ require "roo"
12
+ end
13
+
14
+ # Parse the document and return the text
15
+ # @param [File] data
16
+ # @return [Array<Array<String>>] Array of rows, each row is an array of cells
17
+ def parse(data)
18
+ xlsx_file = Roo::Spreadsheet.open(data)
19
+ xlsx_file.each_with_pagename.flat_map do |_, sheet|
20
+ sheet.map do |row|
21
+ row.map { |i| i.to_s.strip }
22
+ end
23
+ end
24
+ end
25
+ end
26
+ end
27
+ end
@@ -2,6 +2,7 @@
2
2
 
3
3
  require "strscan"
4
4
  require "json"
5
+ require "yaml"
5
6
 
6
7
  module Langchain::Prompt
7
8
  class Base
@@ -52,10 +53,13 @@ module Langchain::Prompt
52
53
  directory_path = save_path.dirname
53
54
  FileUtils.mkdir_p(directory_path) unless directory_path.directory?
54
55
 
55
- if save_path.extname == ".json"
56
+ case save_path.extname
57
+ when ".json"
56
58
  File.write(file_path, to_h.to_json)
59
+ when ".yaml", ".yml"
60
+ File.write(file_path, to_h.to_yaml)
57
61
  else
58
- raise ArgumentError, "#{file_path} must be json"
62
+ raise ArgumentError, "#{file_path} must be json or yaml file"
59
63
  end
60
64
  end
61
65
 
@@ -64,9 +68,9 @@ module Langchain::Prompt
64
68
  #
65
69
  # This method takes a template string and returns an array of input variable names
66
70
  # contained within the template. Input variables are defined as text enclosed in
67
- # curly braces (e.g. "{variable_name}").
71
+ # curly braces (e.g. <code>\{variable_name\}</code>).
68
72
  #
69
- # Content within two consecutive curly braces (e.g. "{{ignore_me}}) are ignored.
73
+ # Content within two consecutive curly braces (e.g. <code>\{\{ignore_me}}</code>) are ignored.
70
74
  #
71
75
  # @param template [String] The template string to extract variables from.
72
76
  #
@@ -2,6 +2,8 @@
2
2
 
3
3
  require "strscan"
4
4
  require "pathname"
5
+ require "json"
6
+ require "yaml"
5
7
 
6
8
  module Langchain::Prompt
7
9
  TYPE_TO_LOADER = {
@@ -22,8 +24,11 @@ module Langchain::Prompt
22
24
  def load_from_path(file_path:)
23
25
  file_path = file_path.is_a?(String) ? Pathname.new(file_path) : file_path
24
26
 
25
- if file_path.extname == ".json"
27
+ case file_path.extname
28
+ when ".json"
26
29
  config = JSON.parse(File.read(file_path))
30
+ when ".yaml", ".yml"
31
+ config = YAML.safe_load(File.read(file_path))
27
32
  else
28
33
  raise ArgumentError, "Got unsupported file type #{file_path.extname}"
29
34
  end
@@ -20,7 +20,7 @@ module Langchain::Prompt
20
20
  end
21
21
 
22
22
  #
23
- # Format the prompt with the inputs. Double {{}} replaced with single {} to adhere to f-string spec.
23
+ # Format the prompt with the inputs. Double <code>{{}}</code> replaced with single <code>{}</code> to adhere to f-string spec.
24
24
  #
25
25
  # @param kwargs [Hash] Any arguments to be passed to the prompt template.
26
26
  # @return [String] A formatted string.
@@ -2,6 +2,8 @@
2
2
 
3
3
  module Langchain::Tool
4
4
  class Base
5
+ include Langchain::DependencyHelper
6
+
5
7
  # How to add additional Tools?
6
8
  # 1. Create a new file in lib/tool/your_tool_name.rb
7
9
  # 2. Add your tool to the TOOLS hash below
@@ -12,7 +14,8 @@ module Langchain::Tool
12
14
  TOOLS = {
13
15
  "calculator" => "Langchain::Tool::Calculator",
14
16
  "search" => "Langchain::Tool::SerpApi",
15
- "wikipedia" => "Langchain::Tool::Wikipedia"
17
+ "wikipedia" => "Langchain::Tool::Wikipedia",
18
+ "database" => "Langchain::Tool::Database"
16
19
  }
17
20
 
18
21
  def self.description(value)
@@ -2,6 +2,15 @@
2
2
 
3
3
  module Langchain::Tool
4
4
  class Calculator < Base
5
+ #
6
+ # A calculator tool that falls back to the Google calculator widget
7
+ #
8
+ # Gem requirements:
9
+ # gem "eqn", "~> 1.6.5"
10
+ # gem "google_search_results", "~> 2.0.0"
11
+ # ENV requirements: ENV["SERPAPI_API_KEY"]
12
+ #
13
+
5
14
  description <<~DESC
6
15
  Useful for getting the result of a math expression.
7
16
 
@@ -0,0 +1,45 @@
1
+ module Langchain::Tool
2
+ class Database < Base
3
+ #
4
+ # Connects to a database, executes SQL queries, and outputs DB schema for Agents to use
5
+ #
6
+ # Gem requirements: gem "sequel", "~> 5.68.0"
7
+ #
8
+
9
+ description <<~DESC
10
+ Useful for getting the result of a database query.
11
+
12
+ The input to this tool should be valid SQL.
13
+ DESC
14
+
15
+ # Establish a database connection
16
+ # @param db_connection_string [String] Database connection info, e.g. 'postgres://user:password@localhost:5432/db_name'
17
+ def initialize(db_connection_string)
18
+ depends_on "sequel"
19
+ require "sequel"
20
+ require "sequel/extensions/schema_dumper"
21
+
22
+ raise StandardError, "db_connection_string parameter cannot be blank" if db_connection_string.empty?
23
+
24
+ @db = Sequel.connect(db_connection_string)
25
+ @db.extension :schema_dumper
26
+ end
27
+
28
+ def schema
29
+ Langchain.logger.info("[#{self.class.name}]".light_blue + ": Dumping schema")
30
+ @db.dump_schema_migration(same_db: true, indexes: false) unless @db.adapter_scheme == :mock
31
+ end
32
+
33
+ # Evaluates a sql expression
34
+ # @param input [String] sql expression
35
+ # @return [Array] results
36
+ def execute(input:)
37
+ Langchain.logger.info("[#{self.class.name}]".light_blue + ": Executing \"#{input}\"")
38
+ begin
39
+ @db[input].to_a
40
+ rescue Sequel::DatabaseError => e
41
+ Langchain.logger.error("[#{self.class.name}]".light_red + ": #{e.message}")
42
+ end
43
+ end
44
+ end
45
+ end
@@ -2,6 +2,12 @@
2
2
 
3
3
  module Langchain::Tool
4
4
  class RubyCodeInterpreter < Base
5
+ #
6
+ # A tool that executes Ruby code in a sandboxed environment.
7
+ #
8
+ # Gem requirements: gem "safe_ruby", "~> 1.0.4"
9
+ #
10
+
5
11
  description <<~DESC
6
12
  A Ruby code interpreter. Use this to execute ruby expressions. Input should be a valid ruby expression. If you want to see the output of the tool, make sure to return a value.
7
13
  DESC
@@ -2,8 +2,12 @@
2
2
 
3
3
  module Langchain::Tool
4
4
  class SerpApi < Base
5
+ #
5
6
  # Wrapper around SerpAPI
6
- # Set ENV["SERPAPI_API_KEY"] to use it
7
+ #
8
+ # Gem requirements: gem "google_search_results", "~> 2.0.0"
9
+ # ENV requirements: ENV["SERPAPI_API_KEY"] (https://serpapi.com/manage-api-key)
10
+ #
7
11
 
8
12
  description <<~DESC
9
13
  A wrapper around Google Search.
@@ -2,7 +2,11 @@
2
2
 
3
3
  module Langchain::Tool
4
4
  class Wikipedia < Base
5
+ #
5
6
  # Tool that adds the capability to search using the Wikipedia API
7
+ #
8
+ # Gem requirements: gem "wikipedia-client", "~> 1.17.0"
9
+ #
6
10
 
7
11
  description <<~DESC
8
12
  A wrapper around Wikipedia.
@@ -4,21 +4,16 @@ require "forwardable"
4
4
 
5
5
  module Langchain::Vectorsearch
6
6
  class Base
7
+ include Langchain::DependencyHelper
7
8
  extend Forwardable
8
9
 
9
- attr_reader :client, :index_name, :llm, :llm_api_key, :llm_client
10
+ attr_reader :client, :index_name, :llm
10
11
 
11
12
  DEFAULT_METRIC = "cosine"
12
13
 
13
- # @param llm [Symbol] The LLM to use
14
- # @param llm_api_key [String] The API key for the LLM
15
- def initialize(llm:, llm_api_key:)
16
- Langchain::LLM::Base.validate_llm!(llm: llm)
17
-
14
+ # @param llm [Object] The LLM client to use
15
+ def initialize(llm:)
18
16
  @llm = llm
19
- @llm_api_key = llm_api_key
20
-
21
- @llm_client = Langchain::LLM.const_get(Langchain::LLM::Base::LLMS.fetch(llm)).new(api_key: llm_api_key)
22
17
  end
23
18
 
24
19
  # Method supported by Vectorsearch DB to create a default schema
@@ -47,7 +42,7 @@ module Langchain::Vectorsearch
47
42
  raise NotImplementedError, "#{self.class.name} does not support asking questions"
48
43
  end
49
44
 
50
- def_delegators :llm_client,
45
+ def_delegators :llm,
51
46
  :default_dimension
52
47
 
53
48
  def generate_prompt(question:, context:)
@@ -68,11 +63,10 @@ module Langchain::Vectorsearch
68
63
  prompt_template.format(question: question)
69
64
  end
70
65
 
71
- def add_data(path: nil, paths: nil)
72
- raise ArgumentError, "Either path or paths must be provided" if path.nil? && paths.nil?
73
- raise ArgumentError, "Either path or paths must be provided, not both" if !path.nil? && !paths.nil?
66
+ def add_data(paths:)
67
+ raise ArgumentError, "Paths must be provided" if paths.to_a.empty?
74
68
 
75
- texts = Array(path || paths)
69
+ texts = Array(paths)
76
70
  .flatten
77
71
  .map { |path| Langchain::Loader.new(path)&.load&.value }
78
72
  .compact
@@ -2,13 +2,21 @@
2
2
 
3
3
  module Langchain::Vectorsearch
4
4
  class Chroma < Base
5
+ #
6
+ # Wrapper around Chroma DB
7
+ #
8
+ # Gem requirements: gem "chroma-db", "~> 0.3.0"
9
+ #
10
+ # Usage:
11
+ # chroma = Langchain::Vectorsearch::Chroma.new(url:, index_name:, llm:, api_key: nil)
12
+ #
13
+
5
14
  # Initialize the Chroma client
6
15
  # @param url [String] The URL of the Chroma server
7
16
  # @param api_key [String] The API key to use
8
17
  # @param index_name [String] The name of the index to use
9
- # @param llm [Symbol] The LLM to use
10
- # @param llm_api_key [String] The API key for the LLM
11
- def initialize(url:, index_name:, llm:, llm_api_key:, api_key: nil)
18
+ # @param llm [Object] The LLM client to use
19
+ def initialize(url:, index_name:, llm:, api_key: nil)
12
20
  depends_on "chroma-db"
13
21
  require "chroma-db"
14
22
 
@@ -18,7 +26,7 @@ module Langchain::Vectorsearch
18
26
 
19
27
  @index_name = index_name
20
28
 
21
- super(llm: llm, llm_api_key: llm_api_key)
29
+ super(llm: llm)
22
30
  end
23
31
 
24
32
  # Add a list of texts to the index
@@ -29,7 +37,7 @@ module Langchain::Vectorsearch
29
37
  ::Chroma::Resources::Embedding.new(
30
38
  # TODO: Add support for passing your own IDs
31
39
  id: SecureRandom.uuid,
32
- embedding: llm_client.embed(text: text),
40
+ embedding: llm.embed(text: text),
33
41
  # TODO: Add support for passing metadata
34
42
  metadata: [], # metadatas[index],
35
43
  document: text # Do we actually need to store the whole original document?
@@ -54,7 +62,7 @@ module Langchain::Vectorsearch
54
62
  query:,
55
63
  k: 4
56
64
  )
57
- embedding = llm_client.embed(text: query)
65
+ embedding = llm.embed(text: query)
58
66
 
59
67
  similarity_search_by_vector(
60
68
  embedding: embedding,
@@ -92,7 +100,7 @@ module Langchain::Vectorsearch
92
100
 
93
101
  prompt = generate_prompt(question: question, context: context)
94
102
 
95
- llm_client.chat(prompt: prompt)
103
+ llm.chat(prompt: prompt)
96
104
  end
97
105
 
98
106
  private
@@ -2,14 +2,23 @@
2
2
 
3
3
  module Langchain::Vectorsearch
4
4
  class Milvus < Base
5
- def initialize(url:, index_name:, llm:, llm_api_key:, api_key: nil)
5
+ #
6
+ # Wrapper around Milvus REST APIs.
7
+ #
8
+ # Gem requirements: gem "milvus", "~> 0.9.0"
9
+ #
10
+ # Usage:
11
+ # milvus = Langchain::Vectorsearch::Milvus.new(url:, index_name:, llm:)
12
+ #
13
+
14
+ def initialize(url:, index_name:, llm:, api_key: nil)
6
15
  depends_on "milvus"
7
16
  require "milvus"
8
17
 
9
18
  @client = ::Milvus::Client.new(url: url)
10
19
  @index_name = index_name
11
20
 
12
- super(llm: llm, llm_api_key: llm_api_key)
21
+ super(llm: llm)
13
22
  end
14
23
 
15
24
  def add_texts(texts:)
@@ -24,7 +33,7 @@ module Langchain::Vectorsearch
24
33
  }, {
25
34
  field_name: "vectors",
26
35
  type: ::Milvus::DATA_TYPES["binary_vector"],
27
- field: Array(texts).map { |text| llm_client.embed(text: text) }
36
+ field: Array(texts).map { |text| llm.embed(text: text) }
28
37
  }
29
38
  ]
30
39
  )
@@ -69,7 +78,7 @@ module Langchain::Vectorsearch
69
78
  end
70
79
 
71
80
  def similarity_search(query:, k: 4)
72
- embedding = llm_client.embed(text: query)
81
+ embedding = llm.embed(text: query)
73
82
 
74
83
  similarity_search_by_vector(
75
84
  embedding: embedding,
@@ -1,8 +1,16 @@
1
1
  # frozen_string_literal: true
2
2
 
3
3
  module Langchain::Vectorsearch
4
- # The PostgreSQL vector search adapter
5
4
  class Pgvector < Base
5
+ #
6
+ # The PostgreSQL vector search adapter
7
+ #
8
+ # Gem requirements: gem "pgvector", "~> 0.2"
9
+ #
10
+ # Usage:
11
+ # pgvector = Langchain::Vectorsearch::Pgvector.new(url:, index_name:, llm:)
12
+ #
13
+
6
14
  # The operators supported by the PostgreSQL vector search adapter
7
15
  OPERATORS = {
8
16
  "cosine_distance" => "<=>",
@@ -14,10 +22,9 @@ module Langchain::Vectorsearch
14
22
 
15
23
  # @param url [String] The URL of the PostgreSQL database
16
24
  # @param index_name [String] The name of the table to use for the index
17
- # @param llm [String] The URL of the Language Layer API
18
- # @param llm_api_key [String] The API key for the Language Layer API
25
+ # @param llm [Object] The LLM client to use
19
26
  # @param api_key [String] The API key for the Vectorsearch DB (not used for PostgreSQL)
20
- def initialize(url:, index_name:, llm:, llm_api_key:, api_key: nil)
27
+ def initialize(url:, index_name:, llm:, api_key: nil)
21
28
  require "pg"
22
29
  require "pgvector"
23
30
 
@@ -30,7 +37,7 @@ module Langchain::Vectorsearch
30
37
  @quoted_table_name = @client.quote_ident(index_name)
31
38
  @operator = OPERATORS[DEFAULT_OPERATOR]
32
39
 
33
- super(llm: llm, llm_api_key: llm_api_key)
40
+ super(llm: llm)
34
41
  end
35
42
 
36
43
  # Add a list of texts to the index
@@ -38,7 +45,7 @@ module Langchain::Vectorsearch
38
45
  # @return [PG::Result] The response from the database
39
46
  def add_texts(texts:)
40
47
  data = texts.flat_map do |text|
41
- [text, llm_client.embed(text: text)]
48
+ [text, llm.embed(text: text)]
42
49
  end
43
50
  values = texts.length.times.map { |i| "($#{2 * i + 1}, $#{2 * i + 2})" }.join(",")
44
51
  client.exec_params(
@@ -67,7 +74,7 @@ module Langchain::Vectorsearch
67
74
  # @param k [Integer] The number of top results to return
68
75
  # @return [Array<Hash>] The results of the search
69
76
  def similarity_search(query:, k: 4)
70
- embedding = llm_client.embed(text: query)
77
+ embedding = llm.embed(text: query)
71
78
 
72
79
  similarity_search_by_vector(
73
80
  embedding: embedding,
@@ -105,7 +112,7 @@ module Langchain::Vectorsearch
105
112
 
106
113
  prompt = generate_prompt(question: question, context: context)
107
114
 
108
- llm_client.chat(prompt: prompt)
115
+ llm.chat(prompt: prompt)
109
116
  end
110
117
  end
111
118
  end
@@ -2,13 +2,21 @@
2
2
 
3
3
  module Langchain::Vectorsearch
4
4
  class Pinecone < Base
5
+ #
6
+ # Wrapper around Pinecone API.
7
+ #
8
+ # Gem requirements: gem "pinecone", "~> 0.1.6"
9
+ #
10
+ # Usage:
11
+ # pinecone = Langchain::Vectorsearch::Pinecone.new(environment:, api_key:, index_name:, llm:)
12
+ #
13
+
5
14
  # Initialize the Pinecone client
6
15
  # @param environment [String] The environment to use
7
16
  # @param api_key [String] The API key to use
8
17
  # @param index_name [String] The name of the index to use
9
- # @param llm [Symbol] The LLM to use
10
- # @param llm_api_key [String] The API key for the LLM
11
- def initialize(environment:, api_key:, index_name:, llm:, llm_api_key:)
18
+ # @param llm [Object] The LLM client to use
19
+ def initialize(environment:, api_key:, index_name:, llm:)
12
20
  depends_on "pinecone"
13
21
  require "pinecone"
14
22
 
@@ -20,7 +28,7 @@ module Langchain::Vectorsearch
20
28
  @client = ::Pinecone::Client.new
21
29
  @index_name = index_name
22
30
 
23
- super(llm: llm, llm_api_key: llm_api_key)
31
+ super(llm: llm)
24
32
  end
25
33
 
26
34
  # Add a list of texts to the index
@@ -34,7 +42,7 @@ module Langchain::Vectorsearch
34
42
  # TODO: Allows passing in your own IDs
35
43
  id: SecureRandom.uuid,
36
44
  metadata: metadata || {content: text},
37
- values: llm_client.embed(text: text)
45
+ values: llm.embed(text: text)
38
46
  }
39
47
  end
40
48
 
@@ -65,7 +73,7 @@ module Langchain::Vectorsearch
65
73
  namespace: "",
66
74
  filter: nil
67
75
  )
68
- embedding = llm_client.embed(text: query)
76
+ embedding = llm.embed(text: query)
69
77
 
70
78
  similarity_search_by_vector(
71
79
  embedding: embedding,
@@ -112,7 +120,7 @@ module Langchain::Vectorsearch
112
120
 
113
121
  prompt = generate_prompt(question: question, context: context)
114
122
 
115
- llm_client.chat(prompt: prompt)
123
+ llm.chat(prompt: prompt)
116
124
  end
117
125
  end
118
126
  end