langchainrb 0.6.16 → 0.6.18

Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +11 -0
  3. data/README.md +16 -1
  4. data/lib/langchain/active_record/hooks.rb +14 -0
  5. data/lib/langchain/agent/react_agent.rb +1 -1
  6. data/lib/langchain/agent/sql_query_agent.rb +2 -2
  7. data/lib/langchain/chunk.rb +16 -0
  8. data/lib/langchain/chunker/base.rb +7 -0
  9. data/lib/langchain/chunker/prompts/semantic_prompt_template.yml +8 -0
  10. data/lib/langchain/chunker/recursive_text.rb +5 -2
  11. data/lib/langchain/chunker/semantic.rb +52 -0
  12. data/lib/langchain/chunker/sentence.rb +4 -2
  13. data/lib/langchain/chunker/text.rb +5 -2
  14. data/lib/langchain/{ai_message.rb → conversation/context.rb} +2 -3
  15. data/lib/langchain/conversation/memory.rb +86 -0
  16. data/lib/langchain/conversation/message.rb +48 -0
  17. data/lib/langchain/{human_message.rb → conversation/prompt.rb} +2 -3
  18. data/lib/langchain/{system_message.rb → conversation/response.rb} +2 -3
  19. data/lib/langchain/conversation.rb +11 -12
  20. data/lib/langchain/llm/ai21.rb +4 -3
  21. data/lib/langchain/llm/anthropic.rb +3 -3
  22. data/lib/langchain/llm/cohere.rb +7 -6
  23. data/lib/langchain/llm/google_palm.rb +24 -20
  24. data/lib/langchain/llm/hugging_face.rb +4 -3
  25. data/lib/langchain/llm/llama_cpp.rb +1 -1
  26. data/lib/langchain/llm/ollama.rb +18 -6
  27. data/lib/langchain/llm/openai.rb +38 -41
  28. data/lib/langchain/llm/replicate.rb +7 -11
  29. data/lib/langchain/llm/response/ai21_response.rb +13 -0
  30. data/lib/langchain/llm/response/anthropic_response.rb +29 -0
  31. data/lib/langchain/llm/response/base_response.rb +79 -0
  32. data/lib/langchain/llm/response/cohere_response.rb +21 -0
  33. data/lib/langchain/llm/response/google_palm_response.rb +36 -0
  34. data/lib/langchain/llm/response/hugging_face_response.rb +13 -0
  35. data/lib/langchain/llm/response/ollama_response.rb +26 -0
  36. data/lib/langchain/llm/response/openai_response.rb +51 -0
  37. data/lib/langchain/llm/response/replicate_response.rb +28 -0
  38. data/lib/langchain/vectorsearch/base.rb +1 -1
  39. data/lib/langchain/vectorsearch/chroma.rb +11 -12
  40. data/lib/langchain/vectorsearch/hnswlib.rb +5 -5
  41. data/lib/langchain/vectorsearch/milvus.rb +2 -2
  42. data/lib/langchain/vectorsearch/pgvector.rb +3 -3
  43. data/lib/langchain/vectorsearch/pinecone.rb +10 -10
  44. data/lib/langchain/vectorsearch/qdrant.rb +5 -5
  45. data/lib/langchain/vectorsearch/weaviate.rb +6 -6
  46. data/lib/langchain/version.rb +1 -1
  47. data/lib/langchain.rb +3 -1
  48. metadata +23 -11
  49. data/lib/langchain/conversation_memory.rb +0 -84
  50. data/lib/langchain/message.rb +0 -35
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 36e0bec4ad6abfd9077c9e7f2d6166ba99acb7dc3859749ee6facfb9409e6379
-  data.tar.gz: 6bd8d3de4f1d31b718381fcef1c21a8b417b2bd8483d7fdc2610cfda3b60a50e
+  metadata.gz: 437c6387ded139ed1a513414bfb7242cdbadf1ba6526c7a89346aa2fa9490fc2
+  data.tar.gz: dd6f437a4bbc4807a16631dd790f66c9de4e9456011b2c4f84302fe3fab1377b
 SHA512:
-  metadata.gz: ed7be8f193d44075f701622fd991127ab32580293fb6d1ab7ccc096eeff8704312ad34cdb7a4cfd09cf8879116ede17a5b017fe15851b9ee78cb159b7e8d8b59
-  data.tar.gz: f70d7a3707ed7fce123c2f9158c338cda3aa38a46abf5598f7d05c6ccd63d5a16a37ba10ff0a7a0a4cd17c0c2aeb2f07a07842a41f16322c48c7c9bae522dda4
+  metadata.gz: 24748539de50dfa816fdb71173ef00a6b04f9737f32926fca919865a49b9812dd9f1fdb286c361c98e33cc994f67e8988ab688bfdf6bf3020d954eb0c791177c
+  data.tar.gz: 283b10460187cada7485e08a19c89e7485925ab2f73a5ad51b06a72e8fd9ee1600ddac9d000f13c0c1af13f6defece9fdcc272489d0df803f94da96fe1c76cfd
data/CHANGELOG.md CHANGED
@@ -1,5 +1,16 @@
 ## [Unreleased]
 
+## [0.6.18] - 2023-10-16
+- Introduce `Langchain::LLM::Response` object
+- Introduce `Langchain::Chunk` object
+- Add the `ask()` method to `Langchain::ActiveRecord::Hooks`
+
+## [0.6.17] - 2023-10-10
+- Bump weaviate and chroma-db deps
+- `Langchain::Chunker::Semantic` chunker
+- Re-structure Conversations class
+- Bug fixes
+
 ## [0.6.16] - 2023-10-02
 - HyDE-style similarity search
 - `Langchain::Chunker::Sentence` chunker
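
The headline change in 0.6.18 is that every `Langchain::LLM` adapter now wraps its raw API payload in a response object instead of returning a bare String or Hash, and callers unwrap it explicitly, as the agent diffs below show. A minimal sketch of the new flow, assuming the response classes expose the `completion` reader that the agents rely on:

```ruby
require "langchain"

llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

# complete() now returns a Langchain::LLM::OpenAIResponse rather than a String
response = llm.complete(prompt: "Say hello")
response.completion # => the completion text, unwrapped the same way the agents now do
```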
data/README.md CHANGED
@@ -59,7 +59,7 @@ client = Langchain::Vectorsearch::Weaviate.new(
 )
 
 # You can instantiate any other supported vector search database:
-client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.3.0"`
+client = Langchain::Vectorsearch::Chroma.new(...) # `gem "chroma-db", "~> 0.6.0"`
 client = Langchain::Vectorsearch::Hnswlib.new(...) # `gem "hnswlib", "~> 0.8.1"`
 client = Langchain::Vectorsearch::Milvus.new(...) # `gem "milvus", "~> 0.9.2"`
 client = Langchain::Vectorsearch::Pinecone.new(...) # `gem "pinecone", "~> 0.1.6"`
@@ -128,6 +128,21 @@ class Product < ActiveRecord::Base
 end
 ```
 
+### Exposed ActiveRecord methods
+```ruby
+# Retrieve similar products based on the query string passed in
+Product.similarity_search(
+  query:,
+  k: # number of results to be retrieved
+)
+```
+```ruby
+# Q&A-style querying based on the question passed in
+Product.ask(
+  question:
+)
+```
+
 Additional info [here](https://github.com/andreibondarev/langchainrb/blob/main/lib/langchain/active_record/hooks.rb#L10-L38).
 
 ### Using Standalone LLMs 🗣️
data/lib/langchain/active_record/hooks.rb CHANGED
@@ -92,6 +92,20 @@ module Langchain
         ids = records.map { |record| record.dig("id") || record.dig("__id") }
         where(id: ids)
       end
+
+      # Ask a question and return the answer
+      #
+      # @param question [String] The question to ask
+      # @param k [Integer] The number of results to have in context
+      # @yield [String] Stream responses back one String at a time
+      # @return [String] The answer to the question
+      def ask(question:, k: 4, &block)
+        class_variable_get(:@@provider).ask(
+          question: question,
+          k: k,
+          &block
+        )
+      end
     end
   end
 end
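
With these hooks in place, a model that is already wired up for vector search (per the README section above) gets Q&A for free. A hedged usage sketch: the `Product` model and the question are illustrative, while the `k: 4` default and the streaming block come straight from the signature above:

```ruby
# Plain retrieval: top 5 nearest products to the query string
Product.similarity_search(query: "winter boots", k: 5)

# Q&A over the 4 most relevant records (k defaults to 4);
# a block, if given, receives the response one String at a time
Product.ask(question: "Which products are waterproof?") do |token|
  print token
end
```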
data/lib/langchain/agent/react_agent.rb CHANGED
@@ -58,7 +58,7 @@ module Langchain::Agent
       max_iterations.times do
         Langchain.logger.info("Sending the prompt to the #{llm.class} LLM", for: self.class)
 
-        response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"])
+        response = llm.complete(prompt: prompt, stop_sequences: ["Observation:"]).completion
 
         # Append the response to the prompt
         prompt += response
data/lib/langchain/agent/sql_query_agent.rb CHANGED
@@ -27,7 +27,7 @@ module Langchain::Agent
 
       # Get the SQL string to execute
       Langchain.logger.info("Passing the initial prompt to the #{llm.class} LLM", for: self.class)
-      sql_string = llm.complete(prompt: prompt)
+      sql_string = llm.complete(prompt: prompt).completion
 
       # Execute the SQL string and collect the results
       Langchain.logger.info("Passing the SQL to the Database: #{sql_string}", for: self.class)
@@ -36,7 +36,7 @@ module Langchain::Agent
       # Pass the results and get the LLM to synthesize the answer to the question
       Langchain.logger.info("Passing the synthesize prompt to the #{llm.class} LLM with results: #{results}", for: self.class)
       prompt2 = create_prompt_for_answer(question: question, sql_query: sql_string, results: results)
-      llm.complete(prompt: prompt2)
+      llm.complete(prompt: prompt2).completion
     end
 
     private
data/lib/langchain/chunk.rb ADDED
@@ -0,0 +1,16 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Chunk
+    # Chunking splits a document into smaller pieces; each piece is wrapped in an instance of Langchain::Chunk
+
+    attr_reader :text
+
+    # Initialize a new chunk
+    # @param [String] text
+    # @return [Langchain::Chunk]
+    def initialize(text:)
+      @text = text
+    end
+  end
+end
data/lib/langchain/chunker/base.rb CHANGED
@@ -8,8 +8,15 @@ module Langchain
   #
   # == Available chunkers
   #
+  # - {Langchain::Chunker::RecursiveText}
   # - {Langchain::Chunker::Text}
+  # - {Langchain::Chunker::Semantic}
+  # - {Langchain::Chunker::Sentence}
   class Base
+    # @return [Array<Langchain::Chunk>]
+    def chunks
+      raise NotImplementedError
+    end
   end
 end
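
Since `Base#chunks` now raises `NotImplementedError` and is documented to return `Array<Langchain::Chunk>`, chunker subclasses have a clear contract to implement. A hypothetical subclass sketch (the `ParagraphChunker` name and its splitting rule are invented for illustration):

```ruby
# Hypothetical custom chunker honoring the new Base contract
class ParagraphChunker < Langchain::Chunker::Base
  def initialize(text)
    @text = text
  end

  # Must return Array<Langchain::Chunk>, like the built-in chunkers below
  def chunks
    @text.split("\n\n").map { |paragraph| Langchain::Chunk.new(text: paragraph) }
  end
end
```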
data/lib/langchain/chunker/prompts/semantic_prompt_template.yml ADDED
@@ -0,0 +1,8 @@
+_type: prompt
+input_variables:
+  - text
+template: |
+  Please split the following text by topics.
+  Output only the paragraphs delimited by "---":
+
+  {text}
data/lib/langchain/chunker/recursive_text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
       @separators = separators
     end
 
-    # @return [Array<String>]
+    # @return [Array<Langchain::Chunk>]
     def chunks
       splitter = Baran::RecursiveCharacterTextSplitter.new(
         chunk_size: chunk_size,
         chunk_overlap: chunk_overlap,
         separators: separators
       )
-      splitter.chunks(text)
+
+      splitter.chunks(text).map do |chunk|
+        Langchain::Chunk.new(text: chunk[:text])
+      end
     end
   end
 end
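
Callers that previously received raw Strings from `chunks` now get `Langchain::Chunk` objects and unwrap them with `#text`. A short sketch, given some `long_text` String and assuming the constructor's default chunk sizes:

```ruby
chunker = Langchain::Chunker::RecursiveText.new(long_text)

# Before: chunks returned Array<String>
# Now: Array<Langchain::Chunk>, unwrapped via #text
chunker.chunks.map(&:text)
```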
data/lib/langchain/chunker/semantic.rb ADDED
@@ -0,0 +1,52 @@
+# frozen_string_literal: true
+
+module Langchain
+  module Chunker
+    #
+    # LLM-powered semantic chunker.
+    # Semantic chunking is a technique of splitting texts by their semantic meaning, e.g.: themes, topics, and ideas.
+    # We use an LLM to accomplish this. The Anthropic LLM is highly recommended for this task as it has the longest context window (100k tokens).
+    #
+    # Usage:
+    #     Langchain::Chunker::Semantic.new(
+    #       text,
+    #       llm: Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
+    #     ).chunks
+    #
+    class Semantic < Base
      attr_reader :text, :llm, :prompt_template
+      # @param [Langchain::LLM::Base] Langchain::LLM::* instance
+      # @param [Langchain::Prompt::PromptTemplate] Optional custom prompt template
+      def initialize(text, llm:, prompt_template: nil)
+        @text = text
+        @llm = llm
+        @prompt_template = prompt_template || default_prompt_template
+      end
+
+      # @return [Array<Langchain::Chunk>]
+      def chunks
+        prompt = prompt_template.format(text: text)
+
+        # Replace static 50k limit with dynamic limit based on text length (max_tokens_to_sample)
+        completion = llm.complete(prompt: prompt, max_tokens_to_sample: 50000)
+        completion
+          .gsub("Here are the paragraphs split by topic:\n\n", "")
+          .split("---")
+          .map(&:strip)
+          .reject(&:empty?)
+          .map do |chunk|
+            Langchain::Chunk.new(text: chunk)
+          end
+      end
+
+      private
+
+      # @return [Langchain::Prompt::PromptTemplate] Default prompt template for semantic chunking
+      def default_prompt_template
+        Langchain::Prompt.load_from_path(
+          file_path: Langchain.root.join("langchain/chunker/prompts/semantic_prompt_template.yml")
+        )
+      end
+    end
+  end
+end
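
Usage mirrors the class docstring above; the only extra step is unwrapping the resulting `Langchain::Chunk` objects, which I assume works the same way as with the other chunkers:

```ruby
chunks = Langchain::Chunker::Semantic.new(
  text,
  llm: Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
).chunks

chunks.map(&:text) # => topic-sized strings, split on the "---" delimiter
```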
data/lib/langchain/chunker/sentence.rb CHANGED
@@ -19,10 +19,12 @@ module Langchain
       @text = text
     end
 
-    # @return [Array<String>]
+    # @return [Array<Langchain::Chunk>]
     def chunks
       ps = PragmaticSegmenter::Segmenter.new(text: text)
-      ps.segment
+      ps.segment.map do |chunk|
+        Langchain::Chunk.new(text: chunk)
+      end
     end
   end
 end
data/lib/langchain/chunker/text.rb CHANGED
@@ -24,14 +24,17 @@ module Langchain
       @separator = separator
     end
 
-    # @return [Array<String>]
+    # @return [Array<Langchain::Chunk>]
     def chunks
       splitter = Baran::CharacterTextSplitter.new(
         chunk_size: chunk_size,
         chunk_overlap: chunk_overlap,
         separator: separator
       )
-      splitter.chunks(text)
+
+      splitter.chunks(text).map do |chunk|
+        Langchain::Chunk.new(text: chunk[:text])
+      end
     end
   end
 end
data/lib/langchain/{ai_message.rb → conversation/context.rb} RENAMED
@@ -1,9 +1,8 @@
 # frozen_string_literal: true
 
 module Langchain
-  class AIMessage < Message
-    def type
-      "ai"
+  class Conversation
+    class Context < Message
     end
   end
 end
data/lib/langchain/conversation/memory.rb ADDED
@@ -0,0 +1,86 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Conversation
+    class Memory
+      attr_reader :examples, :messages
+
+      # The least number of tokens we want to be under the limit by
+      TOKEN_LEEWAY = 20
+
+      def initialize(llm:, messages: [], **options)
+        @llm = llm
+        @context = nil
+        @summary = nil
+        @examples = []
+        @messages = messages
+        @strategy = options.delete(:strategy) || :truncate
+        @options = options
+      end
+
+      def set_context(message)
+        @context = message
+      end
+
+      def add_examples(examples)
+        @examples.concat examples
+      end
+
+      def append_message(message)
+        @messages.append(message)
+      end
+
+      def reduce_messages(exception)
+        case @strategy
+        when :truncate
+          truncate_messages(exception)
+        when :summarize
+          summarize_messages
+        else
+          raise "Unknown strategy: #{@options[:strategy]}"
+        end
+      end
+
+      def context
+        return if @context.nil? && @summary.nil?
+
+        Context.new([@context, @summary].compact.join("\n"))
+      end
+
+      private
+
+      def truncate_messages(exception)
+        raise exception if @messages.size == 1
+
+        token_overflow = exception.token_overflow
+
+        @messages = @messages.drop_while do |message|
+          proceed = token_overflow > -TOKEN_LEEWAY
+          token_overflow -= token_length(message.to_json, model_name, llm: @llm)
+
+          proceed
+        end
+      end
+
+      def summarize_messages
+        history = [@summary, @messages.to_json].compact.join("\n")
+        partitions = [history[0, history.size / 2], history[history.size / 2, history.size]]
+
+        @summary = partitions.map { |messages| @llm.summarize(text: messages.to_json) }.join("\n")
+
+        @messages = [@messages.last]
+      end
+
+      def partition_messages
+      end
+
+      def model_name
+        @llm.class::DEFAULTS[:chat_completion_model_name]
+      end
+
+      def token_length(content, model_name, options)
+        @llm.class::LENGTH_VALIDATOR.token_length(content, model_name, options)
+      end
+    end
+  end
+end
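
The memory supports two overflow strategies: `:truncate` (the default) drops the oldest messages until the reported `token_overflow` is cleared with `TOKEN_LEEWAY` to spare, while `:summarize` halves the serialized history, summarizes each half via `llm.summarize`, and keeps only the last message. `Conversation` forwards the choice through its `memory_strategy` option (see its diff below), so opting into summarization would look like this sketch:

```ruby
chat = Langchain::Conversation.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
  memory_strategy: :summarize # becomes Memory's :strategy, per the Conversation initializer
)
```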
data/lib/langchain/conversation/message.rb ADDED
@@ -0,0 +1,48 @@
+# frozen_string_literal: true
+
+module Langchain
+  class Conversation
+    class Message
+      attr_reader :content
+
+      ROLE_MAPPING = {
+        context: "system",
+        prompt: "user",
+        response: "assistant"
+      }
+
+      def initialize(content)
+        @content = content
+      end
+
+      def role
+        ROLE_MAPPING[type]
+      end
+
+      def to_s
+        content
+      end
+
+      def to_h
+        {
+          role: role,
+          content: content
+        }
+      end
+
+      def ==(other)
+        to_json == other.to_json
+      end
+
+      def to_json(options = {})
+        to_h.to_json
+      end
+
+      private
+
+      def type
+        self.class.to_s.split("::").last.downcase.to_sym
+      end
+    end
+  end
+end
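
The role string sent to the underlying API is derived entirely from the subclass name: `type` downcases the last constant segment and `ROLE_MAPPING` translates it. Directly from the code above:

```ruby
Langchain::Conversation::Prompt.new("What's up?").role     # => "user"
Langchain::Conversation::Response.new("Not much.").role    # => "assistant"
Langchain::Conversation::Context.new("You are terse").role # => "system"

Langchain::Conversation::Prompt.new("hi").to_h # => {role: "user", content: "hi"}
```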
data/lib/langchain/{human_message.rb → conversation/prompt.rb} RENAMED
@@ -1,9 +1,8 @@
 # frozen_string_literal: true
 
 module Langchain
-  class HumanMessage < Message
-    def type
-      "human"
+  class Conversation
+    class Prompt < Message
     end
   end
 end
data/lib/langchain/{system_message.rb → conversation/response.rb} RENAMED
@@ -1,9 +1,8 @@
 # frozen_string_literal: true
 
 module Langchain
-  class SystemMessage < Message
-    def type
-      "system"
+  class Conversation
+    class Response < Message
     end
   end
 end
data/lib/langchain/conversation.rb CHANGED
@@ -28,7 +28,7 @@ module Langchain
     @llm = llm
     @context = nil
     @examples = []
-    @memory = ConversationMemory.new(
+    @memory = ::Langchain::Conversation::Memory.new(
       llm: llm,
       messages: options.delete(:messages) || [],
       strategy: options.delete(:memory_strategy)
@@ -44,48 +44,47 @@ module Langchain
   # Set the context of the conversation. Usually used to set the model's persona.
   # @param message [String] The context of the conversation
   def set_context(message)
-    @memory.set_context SystemMessage.new(message)
+    @memory.set_context ::Langchain::Conversation::Context.new(message)
   end
 
   # Add examples to the conversation. Used to give the model a sense of the conversation.
-  # @param examples [Array<AIMessage|HumanMessage>] The examples to add to the conversation
+  # @param examples [Array<Prompt|Response>] The examples to add to the conversation
   def add_examples(examples)
     @memory.add_examples examples
   end
 
   # Message the model with a prompt and return the response.
   # @param message [String] The prompt to message the model with
-  # @return [AIMessage] The response from the model
+  # @return [Response] The response from the model
   def message(message)
-    human_message = HumanMessage.new(message)
-    @memory.append_message(human_message)
-    ai_message = llm_response(human_message)
+    @memory.append_message ::Langchain::Conversation::Prompt.new(message)
+    ai_message = ::Langchain::Conversation::Response.new(llm_response.chat_completion)
    @memory.append_message(ai_message)
     ai_message
   end
 
   # Messages from conversation memory
-  # @return [Array<AIMessage|HumanMessage>] The messages from the conversation memory
+  # @return [Array<Prompt|Response>] The messages from the conversation memory
   def messages
     @memory.messages
   end
 
   # Context from conversation memory
-  # @return [SystemMessage] Context from conversation memory
+  # @return [Context] Context from conversation memory
   def context
     @memory.context
   end
 
   # Examples from conversation memory
-  # @return [Array<AIMessage|HumanMessage>] Examples from the conversation memory
+  # @return [Array<Prompt|Response>] Examples from the conversation memory
   def examples
     @memory.examples
   end
 
   private
 
-  def llm_response(prompt)
-    @llm.chat(messages: @memory.messages, context: @memory.context, examples: @memory.examples, **@options, &@block)
+  def llm_response
+    @llm.chat(messages: @memory.messages.map(&:to_h), context: @memory.context&.to_s, examples: @memory.examples.map(&:to_h), **@options, &@block)
   rescue Langchain::Utils::TokenLength::TokenLimitExceeded => exception
     @memory.reduce_messages(exception)
     retry
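
Putting the restructured pieces together, an end-to-end exchange reads as below; a sketch assuming an OpenAI-backed conversation, with method names taken from this diff:

```ruby
chat = Langchain::Conversation.new(
  llm: Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
)
chat.set_context("You are a concise assistant") # stored as a Conversation::Context

reply = chat.message("What is langchainrb?") # => Langchain::Conversation::Response
reply.to_s    # the assistant's text
chat.messages # => [Prompt, Response, ...] accumulated in memory
```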
data/lib/langchain/llm/ai21.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
   # gem "ai21", "~> 0.2.1"
   #
   # Usage:
-  # ai21 = Langchain::LLM::AI21.new(api_key:)
+  # ai21 = Langchain::LLM::AI21.new(api_key: ENV["AI21_API_KEY"])
   #
   class AI21 < Base
     DEFAULTS = {
@@ -30,7 +30,7 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to generate a completion for
     # @param params [Hash] The parameters to pass to the API
-    # @return [String] The completion
+    # @return [Langchain::LLM::AI21Response] The completion
     #
     def complete(prompt:, **params)
       parameters = complete_parameters params
@@ -38,7 +38,7 @@ module Langchain::LLM
       parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], client)
 
       response = client.complete(prompt, parameters)
-      response.dig(:completions, 0, :data, :text)
+      Langchain::LLM::AI21Response.new response, model: parameters[:model]
     end
 
     #
@@ -51,6 +51,7 @@ module Langchain::LLM
     def summarize(text:, **params)
       response = client.summarize(text, "TEXT", params)
       response.dig(:summary)
+      # Should we update this to also return a Langchain::LLM::AI21Response?
     end
 
     private
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
   # gem "anthropic", "~> 0.1.0"
   #
   # Usage:
-  # anthropic = Langchain::LLM::Anthropic.new(api_key:)
+  # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
   #
   class Anthropic < Base
     DEFAULTS = {
@@ -32,7 +32,7 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to generate a completion for
     # @param params [Hash] extra parameters passed to Anthropic::Client#complete
-    # @return [String] The completion
+    # @return [Langchain::LLM::AnthropicResponse] The completion
     #
     def complete(prompt:, **params)
       parameters = compose_parameters @defaults[:completion_model_name], params
@@ -43,7 +43,7 @@ module Langchain::LLM
       # parameters[:max_tokens_to_sample] = validate_max_tokens(prompt, parameters[:completion_model_name])
 
       response = client.complete(parameters: parameters)
-      response.dig("completion")
+      Langchain::LLM::AnthropicResponse.new(response)
     end
 
     private
data/lib/langchain/llm/cohere.rb CHANGED
@@ -8,7 +8,7 @@ module Langchain::LLM
   # gem "cohere-ruby", "~> 0.9.6"
   #
   # Usage:
-  # cohere = Langchain::LLM::Cohere.new(api_key: "YOUR_API_KEY")
+  # cohere = Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
   #
   class Cohere < Base
     DEFAULTS = {
@@ -30,14 +30,15 @@ module Langchain::LLM
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @return [Hash] The embedding
+    # @return [Langchain::LLM::CohereResponse] Response object
     #
     def embed(text:)
       response = client.embed(
         texts: [text],
         model: @defaults[:embeddings_model_name]
       )
-      response.dig("embeddings").first
+
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:embeddings_model_name]
     end
 
     #
@@ -45,7 +46,7 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to generate a completion for
     # @param params[:stop_sequences]
-    # @return [Hash] The completion
+    # @return [Langchain::LLM::CohereResponse] Response object
     #
     def complete(prompt:, **params)
       default_params = {
@@ -64,13 +65,13 @@ module Langchain::LLM
       default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], client)
 
       response = client.generate(**default_params)
-      response.dig("generations").first.dig("text")
+      Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
     end
 
     # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
     def chat(...)
       response_text = complete(...)
-      Langchain::AIMessage.new(response_text)
+      ::Langchain::Conversation::Response.new(response_text)
     end
 
     # Generate a summary in English for a given text