langchainrb 0.6.16 → 0.6.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +11 -0
  3. data/README.md +16 -1
  4. data/lib/langchain/active_record/hooks.rb +14 -0
  5. data/lib/langchain/agent/react_agent.rb +1 -1
  6. data/lib/langchain/agent/sql_query_agent.rb +2 -2
  7. data/lib/langchain/chunk.rb +16 -0
  8. data/lib/langchain/chunker/base.rb +7 -0
  9. data/lib/langchain/chunker/prompts/semantic_prompt_template.yml +8 -0
  10. data/lib/langchain/chunker/recursive_text.rb +5 -2
  11. data/lib/langchain/chunker/semantic.rb +52 -0
  12. data/lib/langchain/chunker/sentence.rb +4 -2
  13. data/lib/langchain/chunker/text.rb +5 -2
  14. data/lib/langchain/{ai_message.rb → conversation/context.rb} +2 -3
  15. data/lib/langchain/conversation/memory.rb +86 -0
  16. data/lib/langchain/conversation/message.rb +48 -0
  17. data/lib/langchain/{human_message.rb → conversation/prompt.rb} +2 -3
  18. data/lib/langchain/{system_message.rb → conversation/response.rb} +2 -3
  19. data/lib/langchain/conversation.rb +11 -12
  20. data/lib/langchain/llm/ai21.rb +4 -3
  21. data/lib/langchain/llm/anthropic.rb +3 -3
  22. data/lib/langchain/llm/cohere.rb +7 -6
  23. data/lib/langchain/llm/google_palm.rb +24 -20
  24. data/lib/langchain/llm/hugging_face.rb +4 -3
  25. data/lib/langchain/llm/llama_cpp.rb +1 -1
  26. data/lib/langchain/llm/ollama.rb +18 -6
  27. data/lib/langchain/llm/openai.rb +38 -41
  28. data/lib/langchain/llm/replicate.rb +7 -11
  29. data/lib/langchain/llm/response/ai21_response.rb +13 -0
  30. data/lib/langchain/llm/response/anthropic_response.rb +29 -0
  31. data/lib/langchain/llm/response/base_response.rb +79 -0
  32. data/lib/langchain/llm/response/cohere_response.rb +21 -0
  33. data/lib/langchain/llm/response/google_palm_response.rb +36 -0
  34. data/lib/langchain/llm/response/hugging_face_response.rb +13 -0
  35. data/lib/langchain/llm/response/ollama_response.rb +26 -0
  36. data/lib/langchain/llm/response/openai_response.rb +51 -0
  37. data/lib/langchain/llm/response/replicate_response.rb +28 -0
  38. data/lib/langchain/vectorsearch/base.rb +1 -1
  39. data/lib/langchain/vectorsearch/chroma.rb +11 -12
  40. data/lib/langchain/vectorsearch/hnswlib.rb +5 -5
  41. data/lib/langchain/vectorsearch/milvus.rb +2 -2
  42. data/lib/langchain/vectorsearch/pgvector.rb +3 -3
  43. data/lib/langchain/vectorsearch/pinecone.rb +10 -10
  44. data/lib/langchain/vectorsearch/qdrant.rb +5 -5
  45. data/lib/langchain/vectorsearch/weaviate.rb +6 -6
  46. data/lib/langchain/version.rb +1 -1
  47. data/lib/langchain.rb +3 -1
  48. metadata +23 -11
  49. data/lib/langchain/conversation_memory.rb +0 -84
  50. data/lib/langchain/message.rb +0 -35
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 36e0bec4ad6abfd9077c9e7f2d6166ba99acb7dc3859749ee6facfb9409e6379
- data.tar.gz: 6bd8d3de4f1d31b718381fcef1c21a8b417b2bd8483d7fdc2610cfda3b60a50e
+ metadata.gz: 437c6387ded139ed1a513414bfb7242cdbadf1ba6526c7a89346aa2fa9490fc2
+ data.tar.gz: dd6f437a4bbc4807a16631dd790f66c9de4e9456011b2c4f84302fe3fab1377b
  SHA512:
- metadata.gz: ed7be8f193d44075f701622fd991127ab32580293fb6d1ab7ccc096eeff8704312ad34cdb7a4cfd09cf8879116ede17a5b017fe15851b9ee78cb159b7e8d8b59
- data.tar.gz: f70d7a3707ed7fce123c2f9158c338cda3aa38a46abf5598f7d05c6ccd63d5a16a37ba10ff0a7a0a4cd17c0c2aeb2f07a07842a41f16322c48c7c9bae522dda4
+ metadata.gz: 24748539de50dfa816fdb71173ef00a6b04f9737f32926fca919865a49b9812dd9f1fdb286c361c98e33cc994f67e8988ab688bfdf6bf3020d954eb0c791177c
+ data.tar.gz: 283b10460187cada7485e08a19c89e7485925ab2f73a5ad51b06a72e8fd9ee1600ddac9d000f13c0c1af13f6defece9fdcc272489d0df803f94da96fe1c76cfd
data/CHANGELOG.md CHANGED
@@ -1,5 +1,16 @@
  ## [Unreleased]

+ ## [0.6.18] - 2023-10-16
+ - Introduce `Langchain::LLM::Response`` object
+ - Introduce `Langchain::Chunk` object
+ - Add the ask() method to the Langchain::ActiveRecord::Hooks
+
+ ## [0.6.17] - 2023-10-10
+ - Bump weaviate and chroma-db deps
+ - `Langchain::Chunker::Semantic` chunker
+ - Re-structure Conversations class
+ - Bug fixes
+
  ## [0.6.16] - 2023-10-02
  - HyDE-style similarity search
  - `Langchain::Chunker::Sentence` chunker
data/README.md CHANGED
@@ -59,7 +59,7 @@ client = Langchain::Vectorsearch::Weaviate.new(
  )

  # You can instantiate any other supported vector search database:
- client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
+ client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.6.0"`
  client = Langchain::Vectorsearch::Hnswlib.new(...) # `gem "hnswlib", "~> 0.8.1"`
  client = Langchain::Vectorsearch::Milvus.new(...) # `gem "milvus", "~> 0.9.2"`
  client = Langchain::Vectorsearch::Pinecone.new(...) # `gem "pinecone", "~> 0.1.6"`
@@ -128,6 +128,21 @@ class Product < ActiveRecord::Base
  end
  ```

+ ### Exposed ActiveRecord methods
+ ```ruby
+ # Retrieve similar products based on the query string passed in
+ Product.similarity_search(
+   query:,
+   k: # number of results to be retrieved
+ )
+ ```
+ ```ruby
+ # Q&A-style querying based on the question passed in
+ Product.ask(
+   question:
+ )
+ ```
+
  Additional info [here](https://github.com/andreibondarev/langchainrb/blob/main/lib/langchain/active_record/hooks.rb#L10-L38).

  ### Using Standalone LLMs 🗣️
data/lib/langchain/active_record/hooks.rb CHANGED
@@ -92,6 +92,20 @@ module Langchain
  ids = records.map { |record| record.dig("id") || record.dig("__id") }
  where(id: ids)
  end
+
+ # Ask a question and return the answer
+ #
+ # @param question [String] The question to ask
+ # @param k [Integer] The number of results to have in context
+ # @yield [String] Stream responses back one String at a time
+ # @return [String] The answer to the question
+ def ask(question:, k: 4, &block)
+   class_variable_get(:@@provider).ask(
+     question: question,
+     k: k,
+     &block
+   )
+ end
  end
  end
  end
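For orientation, a hedged sketch of how the new `ask()` hook might be called from a model that already opts into vectorsearch. The `Product` model, the Weaviate settings, and the question are illustrative, and streaming via the block depends on the underlying LLM:

```ruby
# Hypothetical model wired up to a vector search provider (names are illustrative).
class Product < ActiveRecord::Base
  vectorsearch provider: Langchain::Vectorsearch::Weaviate.new(
    api_key: ENV["WEAVIATE_API_KEY"],
    url: ENV["WEAVIATE_URL"],
    index_name: "Products",
    llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
  )

  after_save :upsert_to_vectorsearch
end

# The new ask() hook delegates to the configured provider; a block can stream the answer.
Product.ask(question: "Which products are waterproof?", k: 4) do |chunk|
  print(chunk)
end
```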
data/lib/langchain/agent/react_agent.rb CHANGED
@@ -58,7 +58,7 @@ module Langchain::Agent
  max_iterations.times do
  Langchain.logger.info("Sending the prompt to the #{llm.class} LLM", for: self.class)

- response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
+ response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"]).completion

  # Append the response to the prompt
  prompt += response
data/lib/langchain/agent/sql_query_agent.rb CHANGED
@@ -27,7 +27,7 @@ module Langchain::Agent

  # Get the SQL string to execute
  Langchain.logger.info("Passing the inital prompt to the #{llm.class} LLM", for: self.class)
- sql_string = llm.complete(prompt: prompt)
+ sql_string = llm.complete(prompt: prompt).completion

  # Execute the SQL string and collect the results
  Langchain.logger.info("Passing the SQL to the Database: #{sql_string}", for: self.class)
@@ -36,7 +36,7 @@ module Langchain::Agent
  # Pass the results and get the LLM to synthesize the answer to the question
  Langchain.logger.info("Passing the synthesize prompt to the #{llm.class} LLM with results: #{results}", for: self.class)
  prompt2 = create_prompt_for_answer(question: question, sql_query: sql_string, results: results)
- llm.complete(prompt: prompt2)
+ llm.complete(prompt: prompt2).completion
  end

  private
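Both agents now unwrap the generated text from the new response objects instead of receiving a raw string or hash. A minimal sketch of the pattern, assuming an OpenAI key (the prompt and output are illustrative; other `Langchain::LLM::*` adapters that gained response classes behave the same way):

```ruby
llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

response = llm.complete(prompt: "Say hello")
response            # => a Langchain::LLM::OpenAIResponse object
response.completion # => the generated text that earlier versions returned directly as a String
```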
data/lib/langchain/chunk.rb ADDED
@@ -0,0 +1,16 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Chunk
+     # The chunking process is the process of splitting a document into smaller chunks and creating instances of Langchain::Chunk
+
+     attr_reader :text
+
+     # Initialize a new chunk
+     # @param [String] text
+     # @return [Langchain::Chunk]
+     def initialize(text:)
+       @text = text
+     end
+   end
+ end
data/lib/langchain/chunker/base.rb CHANGED
@@ -8,8 +8,15 @@ module Langchain
  #
  # == Available chunkers
  #
+ # - {Langchain::Chunker::RecursiveText}
  # - {Langchain::Chunker::Text}
+ # - {Langchain::Chunker::Semantic}
+ # - {Langchain::Chunker::Sentence}
  class Base
+   # @return [Array<Langchain::Chunk>]
+   def chunks
+     raise NotImplementedError
+   end
  end
  end
  end
data/lib/langchain/chunker/prompts/semantic_prompt_template.yml ADDED
@@ -0,0 +1,8 @@
+ _type: prompt
+ input_variables:
+   - text
+ template: |
+   Please split the following text by topics.
+   Output only the paragraphs delimited by "---":
+
+   {text}
data/lib/langchain/chunker/recursive_text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
  @separators = separators
  end

- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  splitter = Baran::RecursiveCharacterTextSplitter.new(
  chunk_size: chunk_size,
  chunk_overlap: chunk_overlap,
  separators: separators
  )
- splitter.chunks(text)
+
+ splitter.chunks(text).map do |chunk|
+   Langchain::Chunk.new(text: chunk[:text])
+ end
  end
  end
  end
data/lib/langchain/chunker/semantic.rb ADDED
@@ -0,0 +1,52 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   module Chunker
+     #
+     # LLM-powered semantic chunker.
+     # Semantic chunking is a technique of splitting texts by their semantic meaning, e.g.: themes, topics, and ideas.
+     # We use an LLM to accomplish this. The Anthropic LLM is highly recommended for this task as it has the longest context window (100k tokens).
+     #
+     # Usage:
+     #   Langchain::Chunker::Semantic.new(
+     #     text,
+     #     llm: Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
+     #   ).chunks
+     #
+     class Semantic < Base
+       attr_reader :text, :llm, :prompt_template
+       # @param [Langchain::LLM::Base] Langchain::LLM::* instance
+       # @param [Langchain::Prompt::PromptTemplate] Optional custom prompt template
+       def initialize(text, llm:, prompt_template: nil)
+         @text = text
+         @llm = llm
+         @prompt_template = prompt_template || default_prompt_template
+       end
+
+       # @return [Array<Langchain::Chunk>]
+       def chunks
+         prompt = prompt_template.format(text: text)
+
+         # Replace static 50k limit with dynamic limit based on text length (max_tokens_to_sample)
+         completion = llm.complete(prompt: prompt, max_tokens_to_sample: 50000)
+         completion
+           .gsub("Here are the paragraphs split by topic:\n\n", "")
+           .split("---")
+           .map(&:strip)
+           .reject(&:empty?)
+           .map do |chunk|
+             Langchain::Chunk.new(text: chunk)
+           end
+       end
+
+       private
+
+       # @return [Langchain::Prompt::PromptTemplate] Default prompt template for semantic chunking
+       def default_prompt_template
+         Langchain::Prompt.load_from_path(
+           file_path: Langchain.root.join("langchain/chunker/prompts/semantic_prompt_template.yml")
+         )
+       end
+     end
+   end
+ end
data/lib/langchain/chunker/sentence.rb CHANGED
@@ -19,10 +19,12 @@ module Langchain
  @text = text
  end

- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  ps = PragmaticSegmenter::Segmenter.new(text: text)
- ps.segment
+ ps.segment.map do |chunk|
+   Langchain::Chunk.new(text: chunk)
+ end
  end
  end
  end
data/lib/langchain/chunker/text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
  @separator = separator
  end

- # @return [Array<String>]
+ # @return [Array<Langchain::Chunk>]
  def chunks
  splitter = Baran::CharacterTextSplitter.new(
  chunk_size: chunk_size,
  chunk_overlap: chunk_overlap,
  separator: separator
  )
- splitter.chunks(text)
+
+ splitter.chunks(text).map do |chunk|
+   Langchain::Chunk.new(text: chunk[:text])
+ end
  end
  end
  end
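All of the chunkers above now return `Langchain::Chunk` objects rather than plain strings, so callers read `.text`. A small sketch of the new return shape (the input file and default chunk sizing are assumptions):

```ruby
# Chunk a document with the character-based text chunker; each piece is wrapped in a Langchain::Chunk.
chunks = Langchain::Chunker::Text.new(File.read("document.txt")).chunks

chunks.first       # => a Langchain::Chunk
chunks.map(&:text) # => the plain strings that previous versions returned directly
```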
data/lib/langchain/{ai_message.rb → conversation/context.rb} RENAMED
@@ -1,9 +1,8 @@
  # frozen_string_literal: true

  module Langchain
- class AIMessage < Message
-   def type
-     "ai"
+ class Conversation
+   class Context < Message
  end
  end
  end
data/lib/langchain/conversation/memory.rb ADDED
@@ -0,0 +1,86 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Conversation
+     class Memory
+       attr_reader :examples, :messages
+
+       # The least number of tokens we want to be under the limit by
+       TOKEN_LEEWAY = 20
+
+       def initialize(llm:, messages: [], **options)
+         @llm = llm
+         @context = nil
+         @summary = nil
+         @examples = []
+         @messages = messages
+         @strategy = options.delete(:strategy) || :truncate
+         @options = options
+       end
+
+       def set_context(message)
+         @context = message
+       end
+
+       def add_examples(examples)
+         @examples.concat examples
+       end
+
+       def append_message(message)
+         @messages.append(message)
+       end
+
+       def reduce_messages(exception)
+         case @strategy
+         when :truncate
+           truncate_messages(exception)
+         when :summarize
+           summarize_messages
+         else
+           raise "Unknown strategy: #{@options[:strategy]}"
+         end
+       end
+
+       def context
+         return if @context.nil? && @summary.nil?
+
+         Context.new([@context, @summary].compact.join("\n"))
+       end
+
+       private
+
+       def truncate_messages(exception)
+         raise exception if @messages.size == 1
+
+         token_overflow = exception.token_overflow
+
+         @messages = @messages.drop_while do |message|
+           proceed = token_overflow > -TOKEN_LEEWAY
+           token_overflow -= token_length(message.to_json, model_name, llm: @llm)
+
+           proceed
+         end
+       end
+
+       def summarize_messages
+         history = [@summary, @messages.to_json].compact.join("\n")
+         partitions = [history[0, history.size / 2], history[history.size / 2, history.size]]
+
+         @summary = partitions.map { |messages| @llm.summarize(text: messages.to_json) }.join("\n")
+
+         @messages = [@messages.last]
+       end
+
+       def partition_messages
+       end
+
+       def model_name
+         @llm.class::DEFAULTS[:chat_completion_model_name]
+       end
+
+       def token_length(content, model_name, options)
+         @llm.class::LENGTH_VALIDATOR.token_length(content, model_name, options)
+       end
+     end
+   end
+ end
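The reduction strategy is chosen by the surrounding `Langchain::Conversation`, which forwards its `memory_strategy` option to `Memory#initialize` as `strategy:` (see the conversation.rb hunk below). A hedged sketch of selecting it, assuming an OpenAI key:

```ruby
chat = Langchain::Conversation.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  memory_strategy: :summarize # defaults to :truncate when omitted
)
```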
data/lib/langchain/conversation/message.rb ADDED
@@ -0,0 +1,48 @@
+ # frozen_string_literal: true
+
+ module Langchain
+   class Conversation
+     class Message
+       attr_reader :content
+
+       ROLE_MAPPING = {
+         context: "system",
+         prompt: "user",
+         response: "assistant"
+       }
+
+       def initialize(content)
+         @content = content
+       end
+
+       def role
+         ROLE_MAPPING[type]
+       end
+
+       def to_s
+         content
+       end
+
+       def to_h
+         {
+           role: role,
+           content: content
+         }
+       end
+
+       def ==(other)
+         to_json == other.to_json
+       end
+
+       def to_json(options = {})
+         to_h.to_json
+       end
+
+       private
+
+       def type
+         self.class.to_s.split("::").last.downcase.to_sym
+       end
+     end
+   end
+ end
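Since `type` is derived from the subclass name, the renamed message classes map onto chat roles through `ROLE_MAPPING` as follows (a sketch implied by the code above; hash rendering is illustrative):

```ruby
Langchain::Conversation::Context.new("You are a helpful bot").to_h
# => {role: "system", content: "You are a helpful bot"}

Langchain::Conversation::Prompt.new("Hi there").to_h
# => {role: "user", content: "Hi there"}

Langchain::Conversation::Response.new("Hello!").to_h
# => {role: "assistant", content: "Hello!"}
```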
data/lib/langchain/{human_message.rb → conversation/prompt.rb} RENAMED
@@ -1,9 +1,8 @@
  # frozen_string_literal: true

  module Langchain
- class HumanMessage < Message
-   def type
-     "human"
+ class Conversation
+   class Prompt < Message
  end
  end
  end
data/lib/langchain/{system_message.rb → conversation/response.rb} RENAMED
@@ -1,9 +1,8 @@
  # frozen_string_literal: true

  module Langchain
- class SystemMessage < Message
-   def type
-     "system"
+ class Conversation
+   class Response < Message
  end
  end
  end
data/lib/langchain/conversation.rb CHANGED
@@ -28,7 +28,7 @@ module Langchain
  @llm = llm
  @context = nil
  @examples = []
- @memory = ConversationMemory.new(
+ @memory = ::Langchain::Conversation::Memory.new(
  llm: llm,
  messages: options.delete(:messages) || [],
  strategy: options.delete(:memory_strategy)
@@ -44,48 +44,47 @@ module Langchain
  # Set the context of the conversation. Usually used to set the model's persona.
  # @param message [String] The context of the conversation
  def set_context(message)
- @memory.set_context SystemMessage.new(message)
+ @memory.set_context ::Langchain::Conversation::Context.new(message)
  end

  # Add examples to the conversation. Used to give the model a sense of the conversation.
- # @param examples [Array<AIMessage|HumanMessage>] The examples to add to the conversation
+ # @param examples [Array<Prompt|Response>] The examples to add to the conversation
  def add_examples(examples)
  @memory.add_examples examples
  end

  # Message the model with a prompt and return the response.
  # @param message [String] The prompt to message the model with
- # @return [AIMessage] The response from the model
+ # @return [Response] The response from the model
  def message(message)
- human_message = HumanMessage.new(message)
- @memory.append_message(human_message)
- ai_message = llm_response(human_message)
+ @memory.append_message ::Langchain::Conversation::Prompt.new(message)
+ ai_message = ::Langchain::Conversation::Response.new(llm_response.chat_completion)
  @memory.append_message(ai_message)
  ai_message
  end

  # Messages from conversation memory
- # @return [Array<AIMessage|HumanMessage>] The messages from the conversation memory
+ # @return [Array<Prompt|Response>] The messages from the conversation memory
  def messages
  @memory.messages
  end

  # Context from conversation memory
- # @return [SystemMessage] Context from conversation memory
+ # @return [Context] Context from conversation memory
  def context
  @memory.context
  end

  # Examples from conversation memory
- # @return [Array<AIMessage|HumanMessage>] Examples from the conversation memory
+ # @return [Array<Prompt|Response>] Examples from the conversation memory
  def examples
  @memory.examples
  end

  private

- def llm_response(prompt)
-   @llm.chat(messages: @memory.messages, context: @memory.context, examples: @memory.examples, **@options, &@block)
+ def llm_response
+   @llm.chat(messages: @memory.messages.map(&:to_h), context: @memory.context&.to_s, examples: @memory.examples.map(&:to_h), **@options, &@block)
  rescue Langchain::Utils::TokenLength::TokenLimitExceeded => exception
  @memory.reduce_messages(exception)
  retry
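Putting the restructured pieces together, a minimal sketch of what callers now get back from a conversation (assuming an OpenAI key; the question and reply are illustrative):

```ruby
chat = Langchain::Conversation.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
)
chat.set_context("You are a concise assistant")

reply = chat.message("What is a vector database?")
reply         # => a Langchain::Conversation::Response (previously Langchain::AIMessage)
reply.to_s    # => the assistant's text
chat.messages # => alternating Conversation::Prompt / Conversation::Response objects held by Conversation::Memory
```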
data/lib/langchain/llm/ai21.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "ai21", "~> 0.2.1"
  #
  # Usage:
- # ai21 = Langchain::LLM::AI21.new(api_key:)
+ # ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
  #
  class AI21 < Base
  DEFAULTS = {
@@ -30,7 +30,7 @@ module Langchain::LLM
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params [Hash] The parameters to pass to the API
- # @return [String] The completion
+ # @return [Langchain::LLM::AI21Response] The completion
  #
  def complete(prompt:, **params)
  parameters = complete_parameters params
@@ -38,7 +38,7 @@ module Langchain::LLM
  parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], client)

  response = client.complete(prompt, parameters)
- response.dig(:completions, 0, :data, :text)
+ Langchain::LLM::AI21Response.new response, model: parameters[:model]
  end

  #
@@ -51,6 +51,7 @@ module Langchain::LLM
  def summarize(text:, **params)
  response = client.summarize(text, "TEXT", params)
  response.dig(:summary)
+ # Should we update this to also return a Langchain::LLM::AI21Response?
  end

  private
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "anthropic", "~> 0.1.0"
  #
  # Usage:
- # anthorpic = Langchain::LLM::Anthropic.new(api_key:)
+ # anthorpic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
  #
  class Anthropic < Base
  DEFAULTS = {
@@ -32,7 +32,7 @@ module Langchain::LLM
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params [Hash] extra parameters passed to Anthropic::Client#complete
- # @return [String] The completion
+ # @return [Langchain::LLM::AnthropicResponse] The completion
  #
  def complete(prompt:, **params)
  parameters = compose_parameters @defaults[:completion_model_name], params
@@ -43,7 +43,7 @@ module Langchain::LLM
  # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])

  response = client.complete(parameters: parameters)
- response.dig("completion")
+ Langchain::LLM::AnthropicResponse.new(response)
  end

  private
data/lib/langchain/llm/cohere.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
  # gem "cohere-ruby", "~> 0.9.6"
  #
  # Usage:
- # cohere = Langchain::LLM::Cohere.new(api_key: "YOUR_API_KEY")
+ # cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
  #
  class Cohere < Base
  DEFAULTS = {
@@ -30,14 +30,15 @@ module Langchain::LLM
  # Generate an embedding for a given text
  #
  # @param text [String] The text to generate an embedding for
- # @return [Hash] The embedding
+ # @return [Langchain::LLM::CohereResponse] Response object
  #
  def embed(text:)
  response = client.embed(
  texts: [text],
  model: @defaults[:embeddings_model_name]
  )
- response.dig("embeddings").first
+
+ Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
  end

  #
@@ -45,7 +46,7 @@ module Langchain::LLM
  #
  # @param prompt [String] The prompt to generate a completion for
  # @param params[:stop_sequences]
- # @return [Hash] The completion
+ # @return [Langchain::LLM::CohereResponse] Response object
  #
  def complete(prompt:, **params)
  default_params = {
@@ -64,13 +65,13 @@ module Langchain::LLM
  default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)

  response = client.generate(**default_params)
- response.dig("generations").first.dig("text")
+ Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
  end

  # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
  def chat(...)
  response_text = complete(...)
- Langchain::AIMessage.new(response_text)
+ ::Langchain::Conversation::Response.new(response_text)
  end

  # Generate a summary in English for a given text