langchainrb 0.6.17 → 0.6.19

Files changed (43)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +15 -0
  3. data/README.md +18 -3
  4. data/lib/langchain/active_record/hooks.rb +14 -0
  5. data/lib/langchain/agent/react_agent.rb +1 -1
  6. data/lib/langchain/agent/sql_query_agent.rb +2 -2
  7. data/lib/langchain/chunk.rb +16 -0
  8. data/lib/langchain/chunker/base.rb +4 -0
  9. data/lib/langchain/chunker/recursive_text.rb +5 -2
  10. data/lib/langchain/chunker/semantic.rb +4 -1
  11. data/lib/langchain/chunker/sentence.rb +4 -2
  12. data/lib/langchain/chunker/text.rb +5 -2
  13. data/lib/langchain/conversation.rb +1 -1
  14. data/lib/langchain/llm/ai21.rb +4 -3
  15. data/lib/langchain/llm/anthropic.rb +3 -3
  16. data/lib/langchain/llm/cohere.rb +6 -5
  17. data/lib/langchain/llm/google_palm.rb +14 -10
  18. data/lib/langchain/llm/hugging_face.rb +4 -3
  19. data/lib/langchain/llm/llama_cpp.rb +1 -1
  20. data/lib/langchain/llm/ollama.rb +18 -6
  21. data/lib/langchain/llm/openai.rb +7 -6
  22. data/lib/langchain/llm/replicate.rb +6 -10
  23. data/lib/langchain/llm/response/ai21_response.rb +13 -0
  24. data/lib/langchain/llm/response/anthropic_response.rb +29 -0
  25. data/lib/langchain/llm/response/base_response.rb +79 -0
  26. data/lib/langchain/llm/response/cohere_response.rb +21 -0
  27. data/lib/langchain/llm/response/google_palm_response.rb +36 -0
  28. data/lib/langchain/llm/response/hugging_face_response.rb +13 -0
  29. data/lib/langchain/llm/response/ollama_response.rb +26 -0
  30. data/lib/langchain/llm/response/openai_response.rb +51 -0
  31. data/lib/langchain/llm/response/replicate_response.rb +28 -0
  32. data/lib/langchain/vectorsearch/base.rb +4 -7
  33. data/lib/langchain/vectorsearch/chroma.rb +13 -12
  34. data/lib/langchain/vectorsearch/elasticsearch.rb +147 -0
  35. data/lib/langchain/vectorsearch/hnswlib.rb +5 -5
  36. data/lib/langchain/vectorsearch/milvus.rb +5 -4
  37. data/lib/langchain/vectorsearch/pgvector.rb +12 -6
  38. data/lib/langchain/vectorsearch/pinecone.rb +14 -13
  39. data/lib/langchain/vectorsearch/qdrant.rb +9 -8
  40. data/lib/langchain/vectorsearch/weaviate.rb +9 -8
  41. data/lib/langchain/version.rb +1 -1
  42. data/lib/langchain.rb +5 -0
  43. metadata +27 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 3b9bca59bfb5909f6ac24ebf6dba6074f5faf3d2cdadab1a3b3a8a0f75f98adc
- data.tar.gz: a202726d383d2dc691cb4146e9b36cb7ea6f8ac35382a3df67f6e11d35b3562e
+ metadata.gz: d7be5e031274fba7a4c0d7fc2cd3f472ed83fb66d8c6b355fb71fbf69a825b73
+ data.tar.gz: 745cbc4f3d7b569d2e1407acc8be123f77a0aac2964840d7c3dca215592811ee
  SHA512:
- metadata.gz: b4eaf631f22236035c9e29b3618a70d14487cc9e39b6885e44497ebad2a98670ce88997fdb25144b6467e0caa69a04ce7e625c9e10bc88322131181c2254a570
- data.tar.gz: 981199fe2a0123e46ac3af54946c03d5eaa827473eae02f2e60accd0c680a0bbd40741800e05b79e890038523a1b910502a6cf4ed1f4ebf77845f4b2a2dbc5d9
+ metadata.gz: e1392abe2fc0c4928593bd77d0e62688e3959ec39fd3f7bb5effc784b47599402c611ecc545868178b5d04ec688d68d6406f220697e8bfe40771cc593292a192
+ data.tar.gz: 926bccf20c71af3d31d942cf439336df9edc489a8e5e0359a6c24bb26e5b818be048a7ef63ebcce721bb99392b49e407288ffdd7387dd33d3f0161e92ff6e045
data/CHANGELOG.md CHANGED
@@ -1,5 +1,20 @@
  ## [Unreleased]
 
+ ## [0.6.19] - 2023-10-18
+ - Elasticsearch vector search support
+ - Fix `lib/langchain/railtie.rb` not being loaded with the gem
+
+ ## [0.6.18] - 2023-10-16
+ - Introduce `Langchain::LLM::Response`` object
+ - Introduce `Langchain::Chunk` object
+ - Add the ask() method to the Langchain::ActiveRecord::Hooks
+
+ ## [0.6.17] - 2023-10-10
+ - Bump weaviate and chroma-db deps
+ - `Langchain::Chunker::Semantic` chunker
+ - Re-structure Conversations class
+ - Bug fixes
+
  ## [0.6.16] - 2023-10-02
  - HyDE-style similarity search
  - `Langchain::Chunker::Sentence` chunker
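Elasticsearch support is the headline addition in 0.6.19. The new `Langchain::Vectorsearch::Elasticsearch` class (`data/lib/langchain/vectorsearch/elasticsearch.rb`, +147 in the file list) is not excerpted below, but it plugs into the same vectorsearch interface as its siblings. A minimal sketch, assuming the constructor takes the `url:`, `index_name:`, and `llm:` parameters used by the other vectorsearch classes in this gem:

```ruby
require "langchain"

# Assumed wiring: parameter names mirror the other Vectorsearch classes.
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

es = Langchain::Vectorsearch::Elasticsearch.new(
  url: ENV["ELASTICSEARCH_URL"],
  index_name: "documents",
  llm: llm
)

es.create_default_schema # create the index before adding texts
es.add_texts(texts: ["Ruby was created by Matz in 1995."])
es.similarity_search(query: "Who created Ruby?", k: 4)
```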
data/README.md CHANGED
@@ -19,11 +19,11 @@ Langchain.rb is a library that's an abstraction layer on top many emergent AI, M
 
  Install the gem and add to the application's Gemfile by executing:
 
- $ bundle add langchainrb
+ bundle add langchainrb
 
  If bundler is not being used to manage dependencies, install the gem by executing:
 
- $ gem install langchainrb
+ gem install langchainrb
 
  ## Usage
 
@@ -37,7 +37,7 @@ require "langchain"
  | -------- |:------------------:| -------:| -----------------:| -------:| -----------------:|
  | [Chroma](https://trychroma.com/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
  | [Hnswlib](https://github.com/nmslib/hnswlib/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
- | [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | WIP |
+ | [Milvus](https://milvus.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
  | [Pinecone](https://www.pinecone.io/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
  | [Pgvector](https://github.com/pgvector/pgvector) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
  | [Qdrant](https://qdrant.tech/) | :white_check_mark: | :white_check_mark: | :white_check_mark: | WIP | :white_check_mark: |
@@ -128,6 +128,21 @@ class Product < ActiveRecord::Base
  end
  ```
 
+ ### Exposed ActiveRecord methods
+ ```ruby
+ # Retrieve similar products based on the query string passed in
+ Product.similarity_search(
+   query:,
+   k: # number of results to be retrieved
+ )
+ ```
+ ```ruby
+ # Q&A-style querying based on the question passed in
+ Product.ask(
+   question:
+ )
+ ```
+
  Additional info [here](https://github.com/andreibondarev/langchainrb/blob/main/lib/langchain/active_record/hooks.rb#L10-L38).
 
  ### Using Standalone LLMs 🗣️
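For concreteness, a hypothetical call of the two hooks added above (the query, question, and `k` values are illustrative):

```ruby
# The 5 products whose embeddings are closest to the query.
Product.similarity_search(query: "winter jacket", k: 5)

# Retrieval-augmented answer synthesized by the configured LLM.
Product.ask(question: "What jackets do you sell?")
```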
data/lib/langchain/active_record/hooks.rb CHANGED
@@ -92,6 +92,20 @@ module Langchain
  ids = records.map { |record| record.dig("id") || record.dig("__id") }
  where(id: ids)
  end
+
+ # Ask a question and return the answer
+ #
+ # @param question [String] The question to ask
+ # @param k [Integer] The number of results to have in context
+ # @yield [String] Stream responses back one String at a time
+ # @return [String] The answer to the question
+ def ask(question:, k: 4, &block)
+ class_variable_get(:@@provider).ask(
+ question: question,
+ k: k,
+ &block
+ )
+ end
  end
  end
  end
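Note that `ask` forwards its optional block to the vectorsearch provider, so callers can stream the answer. A small sketch (the question text is illustrative):

```ruby
# Blocking call: returns the full answer built from the k nearest records.
answer = Product.ask(question: "Which products are waterproof?", k: 4)

# Streaming call: the provider yields the answer one String at a time.
Product.ask(question: "Which products are waterproof?") do |chunk|
  print(chunk)
end
```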
data/lib/langchain/agent/react_agent.rb CHANGED
@@ -58,7 +58,7 @@ module Langchain::Agent
  max_iterations.times do
  Langchain.logger.info("Sending the prompt to the #{llm.class} LLM", for: self.class)
 
- response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
+ response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"]).completion
 
  # Append the response to the prompt
  prompt += response
data/lib/langchain/agent/sql_query_agent.rb CHANGED
@@ -27,7 +27,7 @@ module Langchain::Agent
 
  # Get the SQL string to execute
  Langchain.logger.info("Passing the inital prompt to the #{llm.class} LLM", for: self.class)
- sql_string = llm.complete(prompt: prompt)
+ sql_string = llm.complete(prompt: prompt).completion
 
  # Execute the SQL string and collect the results
  Langchain.logger.info("Passing the SQL to the Database: #{sql_string}", for: self.class)
@@ -36,7 +36,7 @@ module Langchain::Agent
  # Pass the results and get the LLM to synthesize the answer to the question
  Langchain.logger.info("Passing the synthesize prompt to the #{llm.class} LLM with results: #{results}", for: self.class)
  prompt2 = create_prompt_for_answer(question: question, sql_query: sql_string, results: results)
- llm.complete(prompt: prompt2)
+ llm.complete(prompt: prompt2).completion
  end
 
  private
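Both agents illustrate the release's central change: `complete` now returns a response object instead of a raw string, and call sites unwrap it with `.completion`. A minimal sketch (the prompt text is illustrative):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

# complete() wraps the raw API payload in a Langchain::LLM::OpenAIResponse.
response = llm.complete(prompt: "Translate 'hello' to French")
response.completion # => the generated text as a String
```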
data/lib/langchain/chunk.rb ADDED
@@ -0,0 +1,16 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Chunk
+     # The chunking process is the process of splitting a document into smaller chunks and creating instances of Langchain::Chunk
+
+     attr_reader :text
+
+     # Initialize a new chunk
+     # @param [String] text
+     # @return [Langchain::Chunk]
+     def initialize(text:)
+       @text = text
+     end
+   end
+ end
data/lib/langchain/chunker/base.rb CHANGED
@@ -13,6 +13,10 @@ module Langchain
  # - {Langchain::Chunker::Semantic}
  # - {Langchain::Chunker::Sentence}
  class Base
+ # @return [Array<Langchain::Chunk>]
+ def chunks
+ raise NotImplementedError
+ end
  end
  end
  end
data/lib/langchain/chunker/recursive_text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
  @separators = separators
  end
 
- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  splitter = Baran::RecursiveCharacterTextSplitter.new(
  chunk_size: chunk_size,
  chunk_overlap: chunk_overlap,
  separators: separators
  )
- splitter.chunks(text)
+
+ splitter.chunks(text).map do |chunk|
+ Langchain::Chunk.new(text: chunk[:text])
+ end
  end
  end
  end
data/lib/langchain/chunker/semantic.rb CHANGED
@@ -23,7 +23,7 @@ module Langchain
  @prompt_template = prompt_template || default_prompt_template
  end
 
- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  prompt = prompt_template.format(text: text)
 
@@ -34,6 +34,9 @@ module Langchain
  .split("---")
  .map(&:strip)
  .reject(&:empty?)
+ .map do |chunk|
+ Langchain::Chunk.new(text: chunk)
+ end
  end
 
  private
data/lib/langchain/chunker/sentence.rb CHANGED
@@ -19,10 +19,12 @@ module Langchain
  @text = text
  end
 
- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  ps = PragmaticSegmenter::Segmenter.new(text: text)
- ps.segment
+ ps.segment.map do |chunk|
+ Langchain::Chunk.new(text: chunk)
+ end
  end
  end
  end
data/lib/langchain/chunker/text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
  @separator = separator
  end
 
- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  splitter = Baran::CharacterTextSplitter.new(
  chunk_size: chunk_size,
  chunk_overlap: chunk_overlap,
  separator: separator
  )
- splitter.chunks(text)
+
+ splitter.chunks(text).map do |chunk|
+ Langchain::Chunk.new(text: chunk[:text])
+ end
  end
  end
  end
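Every chunker's `chunks` method now returns `Array<Langchain::Chunk>` rather than `Array<String>`, so consumers unwrap with `#text`. A sketch using the sentence chunker, assuming a positional text argument as the `@text = text` assignment in its hunk suggests:

```ruby
chunker = Langchain::Chunker::Sentence.new("First sentence. Second sentence.")

# Array<Langchain::Chunk>; map to plain strings with #text.
chunker.chunks.map(&:text)
# => ["First sentence.", "Second sentence."]
```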
data/lib/langchain/conversation.rb CHANGED
@@ -58,7 +58,7 @@ module Langchain
  # @return [Response] The response from the model
  def message(message)
  @memory.append_message ::Langchain::Conversation::Prompt.new(message)
- ai_message = ::Langchain::Conversation::Response.new(llm_response)
+ ai_message = ::Langchain::Conversation::Response.new(llm_response.chat_completion)
  @memory.append_message(ai_message)
  ai_message
  end
data/lib/langchain/llm/ai21.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "ai21", "~> 0.2.1"
  #
  # Usage:
- # ai21 = Langchain::LLM::AI21.new(api_key:)
+ # ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
  #
  class AI21 < Base
  DEFAULTS = {
@@ -30,7 +30,7 @@
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params [Hash] The parameters to pass to the API
- # @return [String] The completion
+ # @return [Langchain::LLM::AI21Response] The completion
  #
  def complete(prompt:, **params)
  parameters = complete_parameters params
@@ -38,7 +38,7 @@
  parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], client)
 
  response = client.complete(prompt, parameters)
- response.dig(:completions, 0, :data, :text)
+ Langchain::LLM::AI21Response.new response, model: parameters[:model]
  end
 
  #
@@ -51,6 +51,7 @@
  def summarize(text:, **params)
  response = client.summarize(text, "TEXT", params)
  response.dig(:summary)
+ # Should we update this to also return a Langchain::LLM::AI21Response?
  end
 
  private
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "anthropic", "~> 0.1.0"
  #
  # Usage:
- # anthorpic = Langchain::LLM::Anthropic.new(api_key:)
+ # anthorpic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
  #
  class Anthropic < Base
  DEFAULTS = {
@@ -32,7 +32,7 @@
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params [Hash] extra parameters passed to Anthropic::Client#complete
- # @return [String] The completion
+ # @return [Langchain::LLM::AnthropicResponse] The completion
  #
  def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params
@@ -43,7 +43,7 @@
  # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
 
  response = client.complete(parameters: parameters)
- response.dig("completion")
+ Langchain::LLM::AnthropicResponse.new(response)
  end
 
  private
data/lib/langchain/llm/cohere.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "cohere-ruby", "~> 0.9.6"
  #
  # Usage:
- # cohere = Langchain::LLM::Cohere.new(api_key: "YOUR_API_KEY")
+ # cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
  #
  class Cohere < Base
  DEFAULTS = {
@@ -30,14 +30,15 @@
  # Generate an embedding for a given text
  #
  # @param text [String] The text to generate an embedding for
- # @return [Hash] The embedding
+ # @return [Langchain::LLM::CohereResponse] Response object
  #
  def embed(text:)
  response = client.embed(
  texts: [text],
  model: @defaults[:embeddings_model_name]
  )
- response.dig("embeddings").first
+
+ Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
  end
 
  #
@@ -45,7 +46,7 @@
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params[:stop_sequences]
- # @return [Hash] The completion
+ # @return [Langchain::LLM::CohereResponse] Response object
  #
  def complete(prompt:, **params)
  default_params = {
@@ -64,7 +65,7 @@
  default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
 
  response = client.generate(**default_params)
- response.dig("generations").first.dig("text")
+ Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
  end
 
  # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
data/lib/langchain/llm/google_palm.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "google_palm_api", "~> 0.1.3"
  #
  # Usage:
- # google_palm = Langchain::LLM::GooglePalm.new(api_key: "YOUR_API_KEY")
+ # google_palm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])
  #
  class GooglePalm < Base
  DEFAULTS = {
@@ -34,13 +34,13 @@
  # Generate an embedding for a given text
  #
  # @param text [String] The text to generate an embedding for
- # @return [Array] The embedding
+ # @return [Langchain::LLM::GooglePalmResponse] Response object
  #
  def embed(text:)
- response = client.embed(
- text: text
- )
- response.dig("embedding", "value")
+ response = client.embed(text: text)
+
+ Langchain::LLM::GooglePalmResponse.new response,
+ model: @defaults[:embeddings_model_name]
  end
 
  #
@@ -48,7 +48,7 @@
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
- # @return [String] The completion
+ # @return [Langchain::LLM::GooglePalmResponse] Response object
  #
  def complete(prompt:, **params)
  default_params = {
@@ -68,7 +68,9 @@
  default_params.merge!(params)
 
  response = client.generate_text(**default_params)
- response.dig("candidates", 0, "output")
+
+ Langchain::LLM::GooglePalmResponse.new response,
+ model: default_params[:model]
  end
 
  #
@@ -79,7 +81,7 @@
  # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
  # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
  # @param options [Hash] extra parameters passed to GooglePalmAPI::Client#generate_chat_message
- # @return [String] The chat completion
+ # @return [Langchain::LLM::GooglePalmResponse] Response object
  #
  def chat(prompt: "", messages: [], context: "", examples: [], **options)
  raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
@@ -108,7 +110,9 @@
  response = client.generate_chat_message(**default_params)
  raise "GooglePalm API returned an error: #{response}" if response.dig("error")
 
- response.dig("candidates", 0, "content")
+ Langchain::LLM::GooglePalmResponse.new response,
+ model: default_params[:model]
+ # TODO: Pass in prompt_tokens: prompt_tokens
  end
 
  #
data/lib/langchain/llm/hugging_face.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "hugging-face", "~> 0.3.4"
  #
  # Usage:
- # hf = Langchain::LLM::HuggingFace.new(api_key: "YOUR_API_KEY")
+ # hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
  #
  class HuggingFace < Base
  # The gem does not currently accept other models:
@@ -34,13 +34,14 @@
  # Generate an embedding for a given text
  #
  # @param text [String] The text to embed
- # @return [Array] The embedding
+ # @return [Langchain::LLM::HuggingFaceResponse] Response object
  #
  def embed(text:)
- client.embedding(
+ response = client.embedding(
  input: text,
  model: DEFAULTS[:embeddings_model_name]
  )
+ Langchain::LLM::HuggingFaceResponse.new(response, model: DEFAULTS[:embeddings_model_name])
  end
  end
  end
data/lib/langchain/llm/llama_cpp.rb CHANGED
@@ -34,7 +34,7 @@ module Langchain::LLM
 
  # @param text [String] The text to embed
  # @param n_threads [Integer] The number of CPU threads to use
- # @return [Array] The embedding
+ # @return [Array<Float>] The embedding
  def embed(text:, n_threads: nil)
  # contexts are kinda stateful when it comes to embeddings, so allocate one each time
  context = embedding_context
data/lib/langchain/llm/ollama.rb CHANGED
@@ -22,18 +22,23 @@ module Langchain::LLM
  @url = url
  end
 
+ #
  # Generate the completion for a given prompt
+ #
  # @param prompt [String] The prompt to complete
  # @param model [String] The model to use
  # @param options [Hash] The options to use (https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
- # @return [String] The completed prompt
+ # @return [Langchain::LLM::OllamaResponse] Response object
+ #
  def complete(prompt:, model: nil, **options)
  response = +""
 
+ model_name = model || DEFAULTS[:completion_model_name]
+
  client.post("api/generate") do |req|
  req.body = {}
  req.body["prompt"] = prompt
- req.body["model"] = model || DEFAULTS[:completion_model_name]
+ req.body["model"] = model_name
 
  req.body["options"] = options if options.any?
 
@@ -47,27 +52,34 @@
  end
  end
 
- response
+ Langchain::LLM::OllamaResponse.new(response, model: model_name)
  end
 
+ #
  # Generate an embedding for a given text
+ #
  # @param text [String] The text to generate an embedding for
  # @param model [String] The model to use
- # @param options [Hash] The options to use (
+ # @param options [Hash] The options to use
+ # @return [Langchain::LLM::OllamaResponse] Response object
+ #
  def embed(text:, model: nil, **options)
+ model_name = model || DEFAULTS[:embeddings_model_name]
+
  response = client.post("api/embeddings") do |req|
  req.body = {}
  req.body["prompt"] = text
- req.body["model"] = model || DEFAULTS[:embeddings_model_name]
+ req.body["model"] = model_name
 
  req.body["options"] = options if options.any?
  end
 
- response.body.dig("embedding")
+ Langchain::LLM::OllamaResponse.new(response.body, model: model_name)
  end
 
  private
 
+ # @return [Faraday::Connection] Faraday client
  def client
  @client ||= Faraday.new(url: url) do |conn|
  conn.request :json
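`complete` and `embed` now both return `Langchain::LLM::OllamaResponse` wrappers. A sketch, assuming a local Ollama server and assuming the response class exposes the same `completion`/`embedding` readers as the other response objects:

```ruby
ollama = Langchain::LLM::Ollama.new(url: "http://localhost:11434")

ollama.complete(prompt: "Hello").completion # assumed reader on OllamaResponse
ollama.embed(text: "Hello").embedding       # assumed reader on OllamaResponse
```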
data/lib/langchain/llm/openai.rb CHANGED
@@ -42,7 +42,7 @@ module Langchain::LLM
  #
  # @param text [String] The text to generate an embedding for
  # @param params extra parameters passed to OpenAI::Client#embeddings
- # @return [Array] The embedding
+ # @return [Langchain::LLM::OpenAIResponse] Response object
  #
  def embed(text:, **params)
  parameters = {model: @defaults[:embeddings_model_name], input: text}
@@ -53,7 +53,7 @@
  client.embeddings(parameters: parameters.merge(params))
  end
 
- response.dig("data").first.dig("embedding")
+ Langchain::LLM::OpenAIResponse.new(response)
  end
 
  #
@@ -61,7 +61,7 @@
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params extra parameters passed to OpenAI::Client#complete
- # @return [String] The completion
+ # @return [Langchain::LLM::Response::OpenaAI] Response object
  #
  def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params
@@ -75,7 +75,7 @@
  client.chat(parameters: parameters)
  end
 
- response.dig("choices", 0, "message", "content")
+ Langchain::LLM::OpenAIResponse.new(response)
  end
 
  #
@@ -120,7 +120,7 @@
  # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
  # @param options [Hash] extra parameters passed to OpenAI::Client#chat
  # @yield [Hash] Stream responses back one token at a time
- # @return [String|Array<String>] The chat completion
+ # @return [Langchain::LLM::OpenAIResponse] Response object
  #
  def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
  raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
@@ -138,7 +138,7 @@
 
  return if block
 
- extract_response response
+ Langchain::LLM::OpenAIResponse.new(response)
  end
 
  #
@@ -154,6 +154,7 @@
  prompt = prompt_template.format(text: text)
 
  complete(prompt: prompt, temperature: @defaults[:temperature])
+ # Should this return a Langchain::LLM::OpenAIResponse as well?
  end
 
  private
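Taken together, `embed`, `complete`, and `chat` now all return `Langchain::LLM::OpenAIResponse`. The `.completion` and `.chat_completion` readers appear at call sites elsewhere in this diff; `.embedding` is assumed to be the matching reader on the response object:

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

llm.embed(text: "Hello").embedding        # assumed reader
llm.complete(prompt: "Hello").completion  # used by the agents above
llm.chat(prompt: "Hello").chat_completion # used by Conversation above
```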
data/lib/langchain/llm/replicate.rb CHANGED
@@ -47,38 +47,34 @@ module Langchain::LLM
  # Generate an embedding for a given text
  #
  # @param text [String] The text to generate an embedding for
- # @return [Hash] The embedding
+ # @return [Langchain::LLM::ReplicateResponse] Response object
  #
  def embed(text:)
  response = embeddings_model.predict(input: text)
 
  until response.finished?
  response.refetch
- sleep(1)
+ sleep(0.1)
  end
 
- response.output
+ Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:embeddings_model_name])
  end
 
  #
  # Generate a completion for a given prompt
  #
  # @param prompt [String] The prompt to generate a completion for
- # @return [Hash] The completion
+ # @return [Langchain::LLM::ReplicateResponse] Reponse object
  #
  def complete(prompt:, **params)
  response = completion_model.predict(prompt: prompt)
 
  until response.finished?
  response.refetch
- sleep(1)
+ sleep(0.1)
  end
 
- # Response comes back as an array of strings, e.g.: ["Hi", "how ", "are ", "you?"]
- # The first array element is missing a space at the end, so we add it manually
- response.output[0] += " "
-
- response.output.join
+ Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:completion_model_name])
  end
 
  # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
data/lib/langchain/llm/response/ai21_response.rb ADDED
@@ -0,0 +1,13 @@
+ # frozen_string_literal: true
+
+ module Langchain::LLM
+   class AI21Response < BaseResponse
+     def completions
+       raw_response.dig(:completions)
+     end
+
+     def completion
+       completions.dig(0, :data, :text)
+     end
+   end
+ end
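Each provider gets a thin wrapper like this over `BaseResponse` (added in `data/lib/langchain/llm/response/base_response.rb`, per the file list). A hypothetical exercise of the class above against a canned payload, assuming `BaseResponse#initialize` stores its first positional argument as `raw_response`, as the `dig` calls above imply:

```ruby
# Illustrative payload shaped like the AI21 completions API.
raw = {completions: [{data: {text: "Bonjour!"}}]}

response = Langchain::LLM::AI21Response.new(raw, model: "j2-ultra")
response.completion # => "Bonjour!"
```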