langchainrb 0.6.17 → 0.6.19

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +15 -0
  3. data/README.md +18 -3
  4. data/lib/langchain/active_record/hooks.rb +14 -0
  5. data/lib/langchain/agent/react_agent.rb +1 -1
  6. data/lib/langchain/agent/sql_query_agent.rb +2 -2
  7. data/lib/langchain/chunk.rb +16 -0
  8. data/lib/langchain/chunker/base.rb +4 -0
  9. data/lib/langchain/chunker/recursive_text.rb +5 -2
  10. data/lib/langchain/chunker/semantic.rb +4 -1
  11. data/lib/langchain/chunker/sentence.rb +4 -2
  12. data/lib/langchain/chunker/text.rb +5 -2
  13. data/lib/langchain/conversation.rb +1 -1
  14. data/lib/langchain/llm/ai21.rb +4 -3
  15. data/lib/langchain/llm/anthropic.rb +3 -3
  16. data/lib/langchain/llm/cohere.rb +6 -5
  17. data/lib/langchain/llm/google_palm.rb +14 -10
  18. data/lib/langchain/llm/hugging_face.rb +4 -3
  19. data/lib/langchain/llm/llama_cpp.rb +1 -1
  20. data/lib/langchain/llm/ollama.rb +18 -6
  21. data/lib/langchain/llm/openai.rb +7 -6
  22. data/lib/langchain/llm/replicate.rb +6 -10
  23. data/lib/langchain/llm/response/ai21_response.rb +13 -0
  24. data/lib/langchain/llm/response/anthropic_response.rb +29 -0
  25. data/lib/langchain/llm/response/base_response.rb +79 -0
  26. data/lib/langchain/llm/response/cohere_response.rb +21 -0
  27. data/lib/langchain/llm/response/google_palm_response.rb +36 -0
  28. data/lib/langchain/llm/response/hugging_face_response.rb +13 -0
  29. data/lib/langchain/llm/response/ollama_response.rb +26 -0
  30. data/lib/langchain/llm/response/openai_response.rb +51 -0
  31. data/lib/langchain/llm/response/replicate_response.rb +28 -0
  32. data/lib/langchain/vectorsearch/base.rb +4 -7
  33. data/lib/langchain/vectorsearch/chroma.rb +13 -12
  34. data/lib/langchain/vectorsearch/elasticsearch.rb +147 -0
  35. data/lib/langchain/vectorsearch/hnswlib.rb +5 -5
  36. data/lib/langchain/vectorsearch/milvus.rb +5 -4
  37. data/lib/langchain/vectorsearch/pgvector.rb +12 -6
  38. data/lib/langchain/vectorsearch/pinecone.rb +14 -13
  39. data/lib/langchain/vectorsearch/qdrant.rb +9 -8
  40. data/lib/langchain/vectorsearch/weaviate.rb +9 -8
  41. data/lib/langchain/version.rb +1 -1
  42. data/lib/langchain.rb +5 -0
  43. metadata +27 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 3b9bca59bfb5909f6ac24ebf6dba6074f5faf3d2cdadab1a3b3a8a0f75f98adc
4
- data.tar.gz: a202726d383d2dc691cb4146e9b36cb7ea6f8ac35382a3df67f6e11d35b3562e
3
+ metadata.gz: d7be5e031274fba7a4c0d7fc2cd3f472ed83fb66d8c6b355fb71fbf69a825b73
4
+ data.tar.gz: 745cbc4f3d7b569d2e1407acc8be123f77a0aac2964840d7c3dca215592811ee
5
5
  SHA512:
6
- metadata.gz: b4eaf631f22236035c9e29b3618a70d14487cc9e39b6885e44497ebad2a98670ce88997fdb25144b6467e0caa69a04ce7e625c9e10bc88322131181c2254a570
7
- data.tar.gz: 981199fe2a0123e46ac3af54946c03d5eaa827473eae02f2e60accd0c680a0bbd40741800e05b79e890038523a1b910502a6cf4ed1f4ebf77845f4b2a2dbc5d9
6
+ metadata.gz: e1392abe2fc0c4928593bd77d0e62688e3959ec39fd3f7bb5effc784b47599402c611ecc545868178b5d04ec688d68d6406f220697e8bfe40771cc593292a192
7
+ data.tar.gz: 926bccf20c71af3d31d942cf439336df9edc489a8e5e0359a6c24bb26e5b818be048a7ef63ebcce721bb99392b49e407288ffdd7387dd33d3f0161e92ff6e045
data/CHANGELOG.md CHANGED
@@ -1,5 +1,20 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.6.19] - 2023-10-18
4
+ - Elasticsearch vector search support
5
+ - Fix `lib/langchain/railtie.rb` not being loaded with the gem
6
+
7
+ ## [0.6.18] - 2023-10-16
8
+ - Introduce `Langchain::LLM::Response` object
9
+ - Introduce `Langchain::Chunk` object
10
+ - Add the ask() method to the Langchain::ActiveRecord::Hooks
11
+
12
+ ## [0.6.17] - 2023-10-10
13
+ - Bump weaviate and chroma-db deps
14
+ - `Langchain::Chunker::Semantic` chunker
15
+ - Re-structure Conversations class
16
+ - Bug fixes
17
+
3
18
  ## [0.6.16] - 2023-10-02
4
19
  - HyDE-style similarity search
5
20
  - `Langchain::Chunker::Sentence` chunker
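For reference, a minimal sketch of the response-object change described under 0.6.18, based on the agent and conversation diffs further down; the constructor arguments and prompt are placeholders, not taken from this diff.

```ruby
# Prior to 0.6.18, complete() returned a raw String/Hash; it now returns a
# Langchain::LLM::*Response wrapper. Key and prompt below are illustrative only.
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
response = llm.complete(prompt: "Say hello")
response.completion # completion text, as consumed by the agents below
```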
data/README.md CHANGED
@@ -19,11 +19,11 @@ Langchain.rb is a library that's an abstraction layer on top many emergent AI, M
19
19
 
20
20
  Install the gem and add to the application's Gemfile by executing:
21
21
 
22
- $ bundle add langchainrb
22
+ bundle add langchainrb
23
23
 
24
24
  If bundler is not being used to manage dependencies, install the gem by executing:
25
25
 
26
- $ gem install langchainrb
26
+ gem install langchainrb
27
27
 
28
28
  ## Usage
29
29
 
@@ -37,7 +37,7 @@ require "langchain"
37
37
  | -------- |:------------------:| -------:| -----------------:| -------:| -----------------:|
38
38
  | [Chroma](https://trychroma.com/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
39
39
  | [Hnswlib](https://github.com/nmslib/hnswlib/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
40
- | [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
40
+ | [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
41
41
  | [Pinecone](https://www.pinecone.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
42
42
  | [Pgvector](https://github.com/pgvector/pgvector) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
43
43
  | [Qdrant](https://qdrant.tech/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
@@ -128,6 +128,21 @@ class Product < ActiveRecord::Base
128
128
  end
129
129
  ```
130
130
 
131
+ ### Exposed ActiveRecord methods
132
+ ```ruby
133
+ # Retrieve similar products based on the query string passed in
134
+ Product.similarity_search(
135
+ query:,
136
+ k: # number of results to be retrieved
137
+ )
138
+ ```
139
+ ```ruby
140
+ # Q&A-style querying based on the question passed in
141
+ Product.ask(
142
+ question:
143
+ )
144
+ ```
145
+
131
146
  Additional info [here](https://github.com/andreibondarev/langchainrb/blob/main/lib/langchain/active_record/hooks.rb#L10-L38).
132
147
 
133
148
  ### Using Standalone LLMs 🗣️
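An illustrative call of the two hooks documented above, assuming a `Product` model configured with `vectorsearch` as shown earlier in the README; the query and question strings are made up.

```ruby
# Hypothetical values for the exposed ActiveRecord hooks
Product.similarity_search(query: "wireless headphones", k: 3)
Product.ask(question: "Which of these products is best for travel?")
```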
data/lib/langchain/active_record/hooks.rb CHANGED
@@ -92,6 +92,20 @@ module Langchain
92
92
  ids = records.map { |record| record.dig("id") || record.dig("__id") }
93
93
  where(id: ids)
94
94
  end
95
+
96
+ # Ask a question and return the answer
97
+ #
98
+ # @param question [String] The question to ask
99
+ # @param k [Integer] The number of results to have in context
100
+ # @yield [String] Stream responses back one String at a time
101
+ # @return [String] The answer to the question
102
+ def ask(question:, k: 4, &block)
103
+ class_variable_get(:@@provider).ask(
104
+ question: question,
105
+ k: k,
106
+ &block
107
+ )
108
+ end
95
109
  end
96
110
  end
97
111
  end
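A usage sketch for the new `ask` hook added above, assuming a model that includes `Langchain::ActiveRecord::Hooks` with a configured vector search provider; streaming behavior follows the `@yield` tag in the method documentation.

```ruby
# k defaults to 4; the optional block receives the streamed answer per the @yield doc
Product.ask(question: "Summarize the catalog", k: 2) do |chunk|
  print chunk
end
```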
data/lib/langchain/agent/react_agent.rb CHANGED
@@ -58,7 +58,7 @@ module Langchain::Agent
58
58
  max_iterations.times do
59
59
  Langchain.logger.info("Sending the prompt to the #{llm.class} LLM", for: self.class)
60
60
 
61
- response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
61
+ response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"]).completion
62
62
 
63
63
  # Append the response to the prompt
64
64
  prompt += response
data/lib/langchain/agent/sql_query_agent.rb CHANGED
@@ -27,7 +27,7 @@ module Langchain::Agent
27
27
 
28
28
  # Get the SQL string to execute
29
29
  Langchain.logger.info("Passing the inital prompt to the #{llm.class} LLM", for: self.class)
30
- sql_string = llm.complete(prompt: prompt)
30
+ sql_string = llm.complete(prompt: prompt).completion
31
31
 
32
32
  # Execute the SQL string and collect the results
33
33
  Langchain.logger.info("Passing the SQL to the Database: #{sql_string}", for: self.class)
@@ -36,7 +36,7 @@ module Langchain::Agent
36
36
  # Pass the results and get the LLM to synthesize the answer to the question
37
37
  Langchain.logger.info("Passing the synthesize prompt to the #{llm.class} LLM with results: #{results}", for: self.class)
38
38
  prompt2 = create_prompt_for_answer(question: question, sql_query: sql_string, results: results)
39
- llm.complete(prompt: prompt2)
39
+ llm.complete(prompt: prompt2).completion
40
40
  end
41
41
 
42
42
  private
data/lib/langchain/chunk.rb ADDED
@@ -0,0 +1,16 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Langchain
4
+ class Chunk
5
+ # The chunking process is the process of splitting a document into smaller chunks and creating instances of Langchain::Chunk
6
+
7
+ attr_reader :text
8
+
9
+ # Initialize a new chunk
10
+ # @param [String] text
11
+ # @return [Langchain::Chunk]
12
+ def initialize(text:)
13
+ @text = text
14
+ end
15
+ end
16
+ end
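To show how the wrapper is meant to be consumed, a small sketch using the `Sentence` chunker changed below; the positional text argument is assumed from that chunker's initializer.

```ruby
# Each chunker now returns Langchain::Chunk objects rather than bare strings
chunks = Langchain::Chunker::Sentence.new("First sentence. Second sentence.").chunks
chunks.map(&:text) # => ["First sentence.", "Second sentence."]
```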
data/lib/langchain/chunker/base.rb CHANGED
@@ -13,6 +13,10 @@ module Langchain
13
13
  # - {Langchain::Chunker::Semantic}
14
14
  # - {Langchain::Chunker::Sentence}
15
15
  class Base
16
+ # @return [Array<Langchain::Chunk>]
17
+ def chunks
18
+ raise NotImplementedError
19
+ end
16
20
  end
17
21
  end
18
22
  end
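Since the base class only fixes the return type, a custom chunker needs little more than a `chunks` method; the class below is a hypothetical example, not part of the gem.

```ruby
# Hypothetical paragraph-based chunker honoring the Array<Langchain::Chunk> contract
class ParagraphChunker < Langchain::Chunker::Base
  def initialize(text)
    @text = text
  end

  def chunks
    @text.split("\n\n").map { |paragraph| Langchain::Chunk.new(text: paragraph) }
  end
end
```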
data/lib/langchain/chunker/recursive_text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
24
24
  @separators = separators
25
25
  end
26
26
 
27
- # @return [Array<String>]
27
+ # @return [Array<Langchain::Chunk>]
28
28
  def chunks
29
29
  splitter = Baran::RecursiveCharacterTextSplitter.new(
30
30
  chunk_size: chunk_size,
31
31
  chunk_overlap: chunk_overlap,
32
32
  separators: separators
33
33
  )
34
- splitter.chunks(text)
34
+
35
+ splitter.chunks(text).map do |chunk|
36
+ Langchain::Chunk.new(text: chunk[:text])
37
+ end
35
38
  end
36
39
  end
37
40
  end
data/lib/langchain/chunker/semantic.rb CHANGED
@@ -23,7 +23,7 @@ module Langchain
23
23
  @prompt_template = prompt_template || default_prompt_template
24
24
  end
25
25
 
26
- # @return [Array<String>]
26
+ # @return [Array<Langchain::Chunk>]
27
27
  def chunks
28
28
  prompt = prompt_template.format(text: text)
29
29
 
@@ -34,6 +34,9 @@ module Langchain
34
34
  .split("---")
35
35
  .map(&:strip)
36
36
  .reject(&:empty?)
37
+ .map do |chunk|
38
+ Langchain::Chunk.new(text: chunk)
39
+ end
37
40
  end
38
41
 
39
42
  private
data/lib/langchain/chunker/sentence.rb CHANGED
@@ -19,10 +19,12 @@ module Langchain
19
19
  @text = text
20
20
  end
21
21
 
22
- # @return [Array<String>]
22
+ # @return [Array<Langchain::Chunk>]
23
23
  def chunks
24
24
  ps = PragmaticSegmenter::Segmenter.new(text: text)
25
- ps.segment
25
+ ps.segment.map do |chunk|
26
+ Langchain::Chunk.new(text: chunk)
27
+ end
26
28
  end
27
29
  end
28
30
  end
data/lib/langchain/chunker/text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
24
24
  @separator = separator
25
25
  end
26
26
 
27
- # @return [Array<String>]
27
+ # @return [Array<Langchain::Chunk>]
28
28
  def chunks
29
29
  splitter = Baran::CharacterTextSplitter.new(
30
30
  chunk_size: chunk_size,
31
31
  chunk_overlap: chunk_overlap,
32
32
  separator: separator
33
33
  )
34
- splitter.chunks(text)
34
+
35
+ splitter.chunks(text).map do |chunk|
36
+ Langchain::Chunk.new(text: chunk[:text])
37
+ end
35
38
  end
36
39
  end
37
40
  end
data/lib/langchain/conversation.rb CHANGED
@@ -58,7 +58,7 @@ module Langchain
58
58
  # @return [Response] The response from the model
59
59
  def message(message)
60
60
  @memory.append_message ::Langchain::Conversation::Prompt.new(message)
61
- ai_message = ::Langchain::Conversation::Response.new(llm_response)
61
+ ai_message = ::Langchain::Conversation::Response.new(llm_response.chat_completion)
62
62
  @memory.append_message(ai_message)
63
63
  ai_message
64
64
  end
data/lib/langchain/llm/ai21.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
8
8
  # gem "ai21", "~> 0.2.1"
9
9
  #
10
10
  # Usage:
11
- # ai21 = Langchain::LLM::AI21.new(api_key:)
11
+ # ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
12
12
  #
13
13
  class AI21 < Base
14
14
  DEFAULTS = {
@@ -30,7 +30,7 @@ module Langchain::LLM
30
30
  #
31
31
  # @param prompt [String] The prompt to generate a completion for
32
32
  # @param params [Hash] The parameters to pass to the API
33
- # @return [String] The completion
33
+ # @return [Langchain::LLM::AI21Response] The completion
34
34
  #
35
35
  def complete(prompt:, **params)
36
36
  parameters = complete_parameters params
@@ -38,7 +38,7 @@ module Langchain::LLM
38
38
  parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], client)
39
39
 
40
40
  response = client.complete(prompt, parameters)
41
- response.dig(:completions, 0, :data, :text)
41
+ Langchain::LLM::AI21Response.new response, model: parameters[:model]
42
42
  end
43
43
 
44
44
  #
@@ -51,6 +51,7 @@ module Langchain::LLM
51
51
  def summarize(text:, **params)
52
52
  response = client.summarize(text, "TEXT", params)
53
53
  response.dig(:summary)
54
+ # Should we update this to also return a Langchain::LLM::AI21Response?
54
55
  end
55
56
 
56
57
  private
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
8
8
  # gem "anthropic", "~> 0.1.0"
9
9
  #
10
10
  # Usage:
11
- # anthorpic = Langchain::LLM::Anthropic.new(api_key:)
11
+ # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
12
12
  #
13
13
  class Anthropic < Base
14
14
  DEFAULTS = {
@@ -32,7 +32,7 @@ module Langchain::LLM
32
32
  #
33
33
  # @param prompt [String] The prompt to generate a completion for
34
34
  # @param params [Hash] extra parameters passed to Anthropic::Client#complete
35
- # @return [String] The completion
35
+ # @return [Langchain::LLM::AnthropicResponse] The completion
36
36
  #
37
37
  def complete(prompt:, **params)
38
38
  parameters = compose_parameters @defaults[:completion_model_name], params
@@ -43,7 +43,7 @@ module Langchain::LLM
43
43
  # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
44
44
 
45
45
  response = client.complete(parameters: parameters)
46
- response.dig("completion")
46
+ Langchain::LLM::AnthropicResponse.new(response)
47
47
  end
48
48
 
49
49
  private
data/lib/langchain/llm/cohere.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
8
8
  # gem "cohere-ruby", "~> 0.9.6"
9
9
  #
10
10
  # Usage:
11
- # cohere = Langchain::LLM::Cohere.new(api_key: "YOUR_API_KEY")
11
+ # cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
12
12
  #
13
13
  class Cohere < Base
14
14
  DEFAULTS = {
@@ -30,14 +30,15 @@ module Langchain::LLM
30
30
  # Generate an embedding for a given text
31
31
  #
32
32
  # @param text [String] The text to generate an embedding for
33
- # @return [Hash] The embedding
33
+ # @return [Langchain::LLM::CohereResponse] Response object
34
34
  #
35
35
  def embed(text:)
36
36
  response = client.embed(
37
37
  texts: [text],
38
38
  model: @defaults[:embeddings_model_name]
39
39
  )
40
- response.dig("embeddings").first
40
+
41
+ Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
41
42
  end
42
43
 
43
44
  #
@@ -45,7 +46,7 @@ module Langchain::LLM
45
46
  #
46
47
  # @param prompt [String] The prompt to generate a completion for
47
48
  # @param params[:stop_sequences]
48
- # @return [Hash] The completion
49
+ # @return [Langchain::LLM::CohereResponse] Response object
49
50
  #
50
51
  def complete(prompt:, **params)
51
52
  default_params = {
@@ -64,7 +65,7 @@ module Langchain::LLM
64
65
  default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
65
66
 
66
67
  response = client.generate(**default_params)
67
- response.dig("generations").first.dig("text")
68
+ Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
68
69
  end
69
70
 
70
71
  # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
data/lib/langchain/llm/google_palm.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
8
8
  # gem "google_palm_api", "~> 0.1.3"
9
9
  #
10
10
  # Usage:
11
- # google_palm = Langchain::LLM::GooglePalm.new(api_key: "YOUR_API_KEY")
11
+ # google_palm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])
12
12
  #
13
13
  class GooglePalm < Base
14
14
  DEFAULTS = {
@@ -34,13 +34,13 @@ module Langchain::LLM
34
34
  # Generate an embedding for a given text
35
35
  #
36
36
  # @param text [String] The text to generate an embedding for
37
- # @return [Array] The embedding
37
+ # @return [Langchain::LLM::GooglePalmResponse] Response object
38
38
  #
39
39
  def embed(text:)
40
- response = client.embed(
41
- text: text
42
- )
43
- response.dig("embedding", "value")
40
+ response = client.embed(text: text)
41
+
42
+ Langchain::LLM::GooglePalmResponse.new response,
43
+ model: @defaults[:embeddings_model_name]
44
44
  end
45
45
 
46
46
  #
@@ -48,7 +48,7 @@ module Langchain::LLM
48
48
  #
49
49
  # @param prompt [String] The prompt to generate a completion for
50
50
  # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
51
- # @return [String] The completion
51
+ # @return [Langchain::LLM::GooglePalmResponse] Response object
52
52
  #
53
53
  def complete(prompt:, **params)
54
54
  default_params = {
@@ -68,7 +68,9 @@ module Langchain::LLM
68
68
  default_params.merge!(params)
69
69
 
70
70
  response = client.generate_text(**default_params)
71
- response.dig("candidates", 0, "output")
71
+
72
+ Langchain::LLM::GooglePalmResponse.new response,
73
+ model: default_params[:model]
72
74
  end
73
75
 
74
76
  #
@@ -79,7 +81,7 @@ module Langchain::LLM
79
81
  # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
80
82
  # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
81
83
  # @param options [Hash] extra parameters passed to GooglePalmAPI::Client#generate_chat_message
82
- # @return [String] The chat completion
84
+ # @return [Langchain::LLM::GooglePalmResponse] Response object
83
85
  #
84
86
  def chat(prompt: "", messages: [], context: "", examples: [], **options)
85
87
  raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
@@ -108,7 +110,9 @@ module Langchain::LLM
108
110
  response = client.generate_chat_message(**default_params)
109
111
  raise "GooglePalm API returned an error: #{response}" if response.dig("error")
110
112
 
111
- response.dig("candidates", 0, "content")
113
+ Langchain::LLM::GooglePalmResponse.new response,
114
+ model: default_params[:model]
115
+ # TODO: Pass in prompt_tokens: prompt_tokens
112
116
  end
113
117
 
114
118
  #
data/lib/langchain/llm/hugging_face.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
8
8
  # gem "hugging-face", "~> 0.3.4"
9
9
  #
10
10
  # Usage:
11
- # hf = Langchain::LLM::HuggingFace.new(api_key: "YOUR_API_KEY")
11
+ # hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
12
12
  #
13
13
  class HuggingFace < Base
14
14
  # The gem does not currently accept other models:
@@ -34,13 +34,14 @@ module Langchain::LLM
34
34
  # Generate an embedding for a given text
35
35
  #
36
36
  # @param text [String] The text to embed
37
- # @return [Array] The embedding
37
+ # @return [Langchain::LLM::HuggingFaceResponse] Response object
38
38
  #
39
39
  def embed(text:)
40
- client.embedding(
40
+ response = client.embedding(
41
41
  input: text,
42
42
  model: DEFAULTS[:embeddings_model_name]
43
43
  )
44
+ Langchain::LLM::HuggingFaceResponse.new(response, model: DEFAULTS[:embeddings_model_name])
44
45
  end
45
46
  end
46
47
  end
data/lib/langchain/llm/llama_cpp.rb CHANGED
@@ -34,7 +34,7 @@ module Langchain::LLM
34
34
 
35
35
  # @param text [String] The text to embed
36
36
  # @param n_threads [Integer] The number of CPU threads to use
37
- # @return [Array] The embedding
37
+ # @return [Array<Float>] The embedding
38
38
  def embed(text:, n_threads: nil)
39
39
  # contexts are kinda stateful when it comes to embeddings, so allocate one each time
40
40
  context = embedding_context
data/lib/langchain/llm/ollama.rb CHANGED
@@ -22,18 +22,23 @@ module Langchain::LLM
22
22
  @url = url
23
23
  end
24
24
 
25
+ #
25
26
  # Generate the completion for a given prompt
27
+ #
26
28
  # @param prompt [String] The prompt to complete
27
29
  # @param model [String] The model to use
28
30
  # @param options [Hash] The options to use (https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
29
- # @return [String] The completed prompt
31
+ # @return [Langchain::LLM::OllamaResponse] Response object
32
+ #
30
33
  def complete(prompt:, model: nil, **options)
31
34
  response = +""
32
35
 
36
+ model_name = model || DEFAULTS[:completion_model_name]
37
+
33
38
  client.post("api/generate") do |req|
34
39
  req.body = {}
35
40
  req.body["prompt"] = prompt
36
- req.body["model"] = model || DEFAULTS[:completion_model_name]
41
+ req.body["model"] = model_name
37
42
 
38
43
  req.body["options"] = options if options.any?
39
44
 
@@ -47,27 +52,34 @@ module Langchain::LLM
47
52
  end
48
53
  end
49
54
 
50
- response
55
+ Langchain::LLM::OllamaResponse.new(response, model: model_name)
51
56
  end
52
57
 
58
+ #
53
59
  # Generate an embedding for a given text
60
+ #
54
61
  # @param text [String] The text to generate an embedding for
55
62
  # @param model [String] The model to use
56
- # @param options [Hash] The options to use (
63
+ # @param options [Hash] The options to use
64
+ # @return [Langchain::LLM::OllamaResponse] Response object
65
+ #
57
66
  def embed(text:, model: nil, **options)
67
+ model_name = model || DEFAULTS[:embeddings_model_name]
68
+
58
69
  response = client.post("api/embeddings") do |req|
59
70
  req.body = {}
60
71
  req.body["prompt"] = text
61
- req.body["model"] = model || DEFAULTS[:embeddings_model_name]
72
+ req.body["model"] = model_name
62
73
 
63
74
  req.body["options"] = options if options.any?
64
75
  end
65
76
 
66
- response.body.dig("embedding")
77
+ Langchain::LLM::OllamaResponse.new(response.body, model: model_name)
67
78
  end
68
79
 
69
80
  private
70
81
 
82
+ # @return [Faraday::Connection] Faraday client
71
83
  def client
72
84
  @client ||= Faraday.new(url: url) do |conn|
73
85
  conn.request :json
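A usage sketch for the updated Ollama adapter, assuming a local Ollama server on the default port and that `OllamaResponse` exposes a `completion` accessor for the accumulated text (the accessor itself is not shown in this diff).

```ruby
ollama = Langchain::LLM::Ollama.new(url: "http://localhost:11434")
response = ollama.complete(prompt: "Name three Ruby web frameworks", model: "llama2")
response.completion # assumed accessor over the streamed text built up above
```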
data/lib/langchain/llm/openai.rb CHANGED
@@ -42,7 +42,7 @@ module Langchain::LLM
42
42
  #
43
43
  # @param text [String] The text to generate an embedding for
44
44
  # @param params extra parameters passed to OpenAI::Client#embeddings
45
- # @return [Array] The embedding
45
+ # @return [Langchain::LLM::OpenAIResponse] Response object
46
46
  #
47
47
  def embed(text:, **params)
48
48
  parameters = {model: @defaults[:embeddings_model_name], input: text}
@@ -53,7 +53,7 @@ module Langchain::LLM
53
53
  client.embeddings(parameters: parameters.merge(params))
54
54
  end
55
55
 
56
- response.dig("data").first.dig("embedding")
56
+ Langchain::LLM::OpenAIResponse.new(response)
57
57
  end
58
58
 
59
59
  #
@@ -61,7 +61,7 @@ module Langchain::LLM
61
61
  #
62
62
  # @param prompt [String] The prompt to generate a completion for
63
63
  # @param params extra parameters passed to OpenAI::Client#complete
64
- # @return [String] The completion
64
+ # @return [Langchain::LLM::OpenAIResponse] Response object
65
65
  #
66
66
  def complete(prompt:, **params)
67
67
  parameters = compose_parameters @defaults[:completion_model_name], params
@@ -75,7 +75,7 @@ module Langchain::LLM
75
75
  client.chat(parameters: parameters)
76
76
  end
77
77
 
78
- response.dig("choices", 0, "message", "content")
78
+ Langchain::LLM::OpenAIResponse.new(response)
79
79
  end
80
80
 
81
81
  #
@@ -120,7 +120,7 @@ module Langchain::LLM
120
120
  # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
121
121
  # @param options [Hash] extra parameters passed to OpenAI::Client#chat
122
122
  # @yield [Hash] Stream responses back one token at a time
123
- # @return [String|Array<String>] The chat completion
123
+ # @return [Langchain::LLM::OpenAIResponse] Response object
124
124
  #
125
125
  def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
126
126
  raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
@@ -138,7 +138,7 @@ module Langchain::LLM
138
138
 
139
139
  return if block
140
140
 
141
- extract_response response
141
+ Langchain::LLM::OpenAIResponse.new(response)
142
142
  end
143
143
 
144
144
  #
@@ -154,6 +154,7 @@ module Langchain::LLM
154
154
  prompt = prompt_template.format(text: text)
155
155
 
156
156
  complete(prompt: prompt, temperature: @defaults[:temperature])
157
+ # Should this return a Langchain::LLM::OpenAIResponse as well?
157
158
  end
158
159
 
159
160
  private
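The net effect of the OpenAI changes, sketched with one assumed accessor: `chat_completion` is confirmed by conversation.rb above, while `embedding` is not shown in this diff.

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
llm.chat(prompt: "Hi!").chat_completion # replaces the old raw String return
llm.embed(text: "Hello").embedding      # assumed accessor on OpenAIResponse
```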
data/lib/langchain/llm/replicate.rb CHANGED
@@ -47,38 +47,34 @@ module Langchain::LLM
47
47
  # Generate an embedding for a given text
48
48
  #
49
49
  # @param text [String] The text to generate an embedding for
50
- # @return [Hash] The embedding
50
+ # @return [Langchain::LLM::ReplicateResponse] Response object
51
51
  #
52
52
  def embed(text:)
53
53
  response = embeddings_model.predict(input: text)
54
54
 
55
55
  until response.finished?
56
56
  response.refetch
57
- sleep(1)
57
+ sleep(0.1)
58
58
  end
59
59
 
60
- response.output
60
+ Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:embeddings_model_name])
61
61
  end
62
62
 
63
63
  #
64
64
  # Generate a completion for a given prompt
65
65
  #
66
66
  # @param prompt [String] The prompt to generate a completion for
67
- # @return [Hash] The completion
67
+ # @return [Langchain::LLM::ReplicateResponse] Response object
68
68
  #
69
69
  def complete(prompt:, **params)
70
70
  response = completion_model.predict(prompt: prompt)
71
71
 
72
72
  until response.finished?
73
73
  response.refetch
74
- sleep(1)
74
+ sleep(0.1)
75
75
  end
76
76
 
77
- # Response comes back as an array of strings, e.g.: ["Hi", "how ", "are ", "you?"]
78
- # The first array element is missing a space at the end, so we add it manually
79
- response.output[0] += " "
80
-
81
- response.output.join
77
+ Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:completion_model_name])
82
78
  end
83
79
 
84
80
  # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
data/lib/langchain/llm/response/ai21_response.rb ADDED
@@ -0,0 +1,13 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Langchain::LLM
4
+ class AI21Response < BaseResponse
5
+ def completions
6
+ raw_response.dig(:completions)
7
+ end
8
+
9
+ def completion
10
+ completions.dig(0, :data, :text)
11
+ end
12
+ end
13
+ end
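An illustrative round trip through the new wrapper; the raw hash shape is implied by the dig path above rather than captured from the real AI21 API, and the `model:` keyword is assumed to be optional.

```ruby
raw = {completions: [{data: {text: "Hello there!"}}]} # shape implied by the dig above
Langchain::LLM::AI21Response.new(raw).completion      # => "Hello there!"
```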