langchainrb 0.6.16 → 0.6.18

Files changed (50)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +11 -0
  3. data/README.md +16 -1
  4. data/lib/langchain/active_record/hooks.rb +14 -0
  5. data/lib/langchain/agent/react_agent.rb +1 -1
  6. data/lib/langchain/agent/sql_query_agent.rb +2 -2
  7. data/lib/langchain/chunk.rb +16 -0
  8. data/lib/langchain/chunker/base.rb +7 -0
  9. data/lib/langchain/chunker/prompts/semantic_prompt_template.yml +8 -0
  10. data/lib/langchain/chunker/recursive_text.rb +5 -2
  11. data/lib/langchain/chunker/semantic.rb +52 -0
  12. data/lib/langchain/chunker/sentence.rb +4 -2
  13. data/lib/langchain/chunker/text.rb +5 -2
  14. data/lib/langchain/{ai_message.rb → conversation/context.rb} +2 -3
  15. data/lib/langchain/conversation/memory.rb +86 -0
  16. data/lib/langchain/conversation/message.rb +48 -0
  17. data/lib/langchain/{human_message.rb → conversation/prompt.rb} +2 -3
  18. data/lib/langchain/{system_message.rb → conversation/response.rb} +2 -3
  19. data/lib/langchain/conversation.rb +11 -12
  20. data/lib/langchain/llm/ai21.rb +4 -3
  21. data/lib/langchain/llm/anthropic.rb +3 -3
  22. data/lib/langchain/llm/cohere.rb +7 -6
  23. data/lib/langchain/llm/google_palm.rb +24 -20
  24. data/lib/langchain/llm/hugging_face.rb +4 -3
  25. data/lib/langchain/llm/llama_cpp.rb +1 -1
  26. data/lib/langchain/llm/ollama.rb +18 -6
  27. data/lib/langchain/llm/openai.rb +38 -41
  28. data/lib/langchain/llm/replicate.rb +7 -11
  29. data/lib/langchain/llm/response/ai21_response.rb +13 -0
  30. data/lib/langchain/llm/response/anthropic_response.rb +29 -0
  31. data/lib/langchain/llm/response/base_response.rb +79 -0
  32. data/lib/langchain/llm/response/cohere_response.rb +21 -0
  33. data/lib/langchain/llm/response/google_palm_response.rb +36 -0
  34. data/lib/langchain/llm/response/hugging_face_response.rb +13 -0
  35. data/lib/langchain/llm/response/ollama_response.rb +26 -0
  36. data/lib/langchain/llm/response/openai_response.rb +51 -0
  37. data/lib/langchain/llm/response/replicate_response.rb +28 -0
  38. data/lib/langchain/vectorsearch/base.rb +1 -1
  39. data/lib/langchain/vectorsearch/chroma.rb +11 -12
  40. data/lib/langchain/vectorsearch/hnswlib.rb +5 -5
  41. data/lib/langchain/vectorsearch/milvus.rb +2 -2
  42. data/lib/langchain/vectorsearch/pgvector.rb +3 -3
  43. data/lib/langchain/vectorsearch/pinecone.rb +10 -10
  44. data/lib/langchain/vectorsearch/qdrant.rb +5 -5
  45. data/lib/langchain/vectorsearch/weaviate.rb +6 -6
  46. data/lib/langchain/version.rb +1 -1
  47. data/lib/langchain.rb +3 -1
  48. metadata +23 -11
  49. data/lib/langchain/conversation_memory.rb +0 -84
  50. data/lib/langchain/message.rb +0 -35
data/lib/langchain/llm/google_palm.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
   # gem "google_palm_api", "~> 0.1.3"
   #
   # Usage:
-  # google_palm = Langchain::LLM::GooglePalm.new(api_key: "YOUR_API_KEY")
+  # google_palm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])
   #
   class GooglePalm < Base
     DEFAULTS = {
@@ -20,7 +20,7 @@ module Langchain::LLM
     }.freeze
     LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
     ROLE_MAPPING = {
-      "human" => "user"
+      "assistant" => "ai"
     }

     def initialize(api_key:, default_options: {})
@@ -34,13 +34,13 @@ module Langchain::LLM
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @return [Array] The embedding
+    # @return [Langchain::LLM::GooglePalmResponse] Response object
     #
     def embed(text:)
-      response = client.embed(
-        text: text
-      )
-      response.dig("embedding", "value")
+      response = client.embed(text: text)
+
+      Langchain::LLM::GooglePalmResponse.new response,
+        model: @defaults[:embeddings_model_name]
     end

     #
@@ -48,7 +48,7 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to generate a completion for
     # @param params extra parameters passed to GooglePalmAPI::Client#generate_text
-    # @return [String] The completion
+    # @return [Langchain::LLM::GooglePalmResponse] Response object
     #
     def complete(prompt:, **params)
       default_params = {
@@ -68,18 +68,20 @@ module Langchain::LLM
       default_params.merge!(params)

       response = client.generate_text(**default_params)
-      response.dig("candidates", 0, "output")
+
+      Langchain::LLM::GooglePalmResponse.new response,
+        model: default_params[:model]
     end

     #
     # Generate a chat completion for a given prompt
     #
-    # @param prompt [HumanMessage] The prompt to generate a chat completion for
-    # @param messages [Array<AIMessage|HumanMessage>] The messages that have been sent in the conversation
-    # @param context [SystemMessage] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<AIMessage|HumanMessage>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
+    # @param prompt [String] The prompt to generate a chat completion for
+    # @param messages [Array<Hash>] The messages that have been sent in the conversation
+    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
     # @param options [Hash] extra parameters passed to GooglePalmAPI::Client#generate_chat_message
-    # @return [AIMessage] The chat completion
+    # @return [Langchain::LLM::GooglePalmResponse] Response object
     #
     def chat(prompt: "", messages: [], context: "", examples: [], **options)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
@@ -87,7 +89,7 @@ module Langchain::LLM
       default_params = {
         temperature: @defaults[:temperature],
         model: @defaults[:chat_completion_model_name],
-        context: context.to_s,
+        context: context,
         messages: compose_chat_messages(prompt: prompt, messages: messages),
         examples: compose_examples(examples)
       }
@@ -108,7 +110,9 @@ module Langchain::LLM
       response = client.generate_chat_message(**default_params)
       raise "GooglePalm API returned an error: #{response}" if response.dig("error")

-      Langchain::AIMessage.new(response.dig("candidates", 0, "content"))
+      Langchain::LLM::GooglePalmResponse.new response,
+        model: default_params[:model]
+      # TODO: Pass in prompt_tokens: prompt_tokens
     end

     #
@@ -150,8 +154,8 @@ module Langchain::LLM
     def compose_examples(examples)
       examples.each_slice(2).map do |example|
         {
-          input: {content: example.first.content},
-          output: {content: example.last.content}
+          input: {content: example.first[:content]},
+          output: {content: example.last[:content]}
         }
       end
     end
@@ -159,8 +163,8 @@ module Langchain::LLM
     def transform_messages(messages)
       messages.map do |message|
         {
-          author: ROLE_MAPPING.fetch(message.type, message.type),
-          content: message.content
+          author: ROLE_MAPPING.fetch(message[:role], message[:role]),
+          content: message[:content]
         }
       end
     end
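The practical upshot of the GooglePalm changes: `#embed`, `#complete`, and `#chat` now wrap the raw PaLM payload in a `Langchain::LLM::GooglePalmResponse` instead of returning a bare string or array, and chat messages are plain `{role:, content:}` hashes rather than message objects. A minimal sketch of the new call pattern (the prompt text and env var name are illustrative, not from the gem's docs):

    # Sketch of the 0.6.18 GooglePalm return values; prompt text is illustrative.
    google_palm = Langchain::LLM::GooglePalm.new(api_key: ENV["GOOGLE_PALM_API_KEY"])

    completion = google_palm.complete(prompt: "Say hello in Ruby")
    completion.completion    # => first candidate's text ("candidates" -> 0 -> "output")
    completion.raw_response  # => the untouched API payload is still available

    embedding = google_palm.embed(text: "Say hello in Ruby")
    embedding.embedding      # => Array<Float> pulled from "embedding" -> "value"

    chat = google_palm.chat(
      prompt: "What is an Array?",
      context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby",
      messages: [{role: "user", content: "Hi there"}]  # messages are now plain hashes
    )
    chat.chat_completion     # => first candidate's "content"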
data/lib/langchain/llm/hugging_face.rb
@@ -8,7 +8,7 @@ module Langchain::LLM
   # gem "hugging-face", "~> 0.3.4"
   #
   # Usage:
-  # hf = Langchain::LLM::HuggingFace.new(api_key: "YOUR_API_KEY")
+  # hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])
   #
   class HuggingFace < Base
     # The gem does not currently accept other models:
@@ -34,13 +34,14 @@ module Langchain::LLM
     # Generate an embedding for a given text
     #
     # @param text [String] The text to embed
-    # @return [Array] The embedding
+    # @return [Langchain::LLM::HuggingFaceResponse] Response object
    #
     def embed(text:)
-      client.embedding(
+      response = client.embedding(
         input: text,
         model: DEFAULTS[:embeddings_model_name]
       )
+      Langchain::LLM::HuggingFaceResponse.new(response, model: DEFAULTS[:embeddings_model_name])
     end
   end
 end
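As with the other adapters, `HuggingFace#embed` now returns a response object; the vector that used to be returned directly is reachable via `#embedding`. A short sketch (the input text is illustrative):

    # Sketch of the new HuggingFace embed return type; input text is illustrative.
    hf = Langchain::LLM::HuggingFace.new(api_key: ENV["HUGGING_FACE_API_KEY"])

    response = hf.embed(text: "Ruby is a programmer's best friend")
    response.embedding  # => Array<Float> (what #embed used to return directly)
    response.model      # => DEFAULTS[:embeddings_model_name]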
data/lib/langchain/llm/llama_cpp.rb
@@ -34,7 +34,7 @@ module Langchain::LLM

     # @param text [String] The text to embed
     # @param n_threads [Integer] The number of CPU threads to use
-    # @return [Array] The embedding
+    # @return [Array<Float>] The embedding
     def embed(text:, n_threads: nil)
       # contexts are kinda stateful when it comes to embeddings, so allocate one each time
       context = embedding_context
data/lib/langchain/llm/ollama.rb
@@ -22,18 +22,23 @@ module Langchain::LLM
       @url = url
     end

+    #
     # Generate the completion for a given prompt
+    #
     # @param prompt [String] The prompt to complete
     # @param model [String] The model to use
     # @param options [Hash] The options to use (https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
-    # @return [String] The completed prompt
+    # @return [Langchain::LLM::OllamaResponse] Response object
+    #
     def complete(prompt:, model: nil, **options)
       response = +""

+      model_name = model || DEFAULTS[:completion_model_name]
+
       client.post("api/generate") do |req|
         req.body = {}
         req.body["prompt"] = prompt
-        req.body["model"] = model || DEFAULTS[:completion_model_name]
+        req.body["model"] = model_name

         req.body["options"] = options if options.any?

@@ -47,27 +52,34 @@ module Langchain::LLM
         end
       end

-      response
+      Langchain::LLM::OllamaResponse.new(response, model: model_name)
     end

+    #
     # Generate an embedding for a given text
+    #
     # @param text [String] The text to generate an embedding for
     # @param model [String] The model to use
-    # @param options [Hash] The options to use (
+    # @param options [Hash] The options to use
+    # @return [Langchain::LLM::OllamaResponse] Response object
+    #
     def embed(text:, model: nil, **options)
+      model_name = model || DEFAULTS[:embeddings_model_name]
+
       response = client.post("api/embeddings") do |req|
         req.body = {}
         req.body["prompt"] = text
-        req.body["model"] = model || DEFAULTS[:embeddings_model_name]
+        req.body["model"] = model_name

         req.body["options"] = options if options.any?
       end

-      response.body.dig("embedding")
+      Langchain::LLM::OllamaResponse.new(response.body, model: model_name)
     end

     private

+    # @return [Faraday::Connection] Faraday client
    def client
      @client ||= Faraday.new(url: url) do |conn|
        conn.request :json
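A sketch of what the Ollama adapter now returns, assuming a locally running Ollama server and that the constructor takes the `url:` keyword implied by `@url = url` (model name and prompt are illustrative):

    # Sketch only; assumes a local Ollama server and a `url:` constructor keyword.
    ollama = Langchain::LLM::Ollama.new(url: "http://localhost:11434")

    completion = ollama.complete(prompt: "Why is the sky blue?", model: "llama2")
    completion.model        # => "llama2" (falls back to DEFAULTS[:completion_model_name] when omitted)
    completion.completions  # => [accumulated response text] wrapped by OllamaResponse

    embedding = ollama.embed(text: "Why is the sky blue?")
    embedding.embedding     # => Array<Float> read from the "embedding" key of the response body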
data/lib/langchain/llm/openai.rb
@@ -11,6 +11,7 @@ module Langchain::LLM
   #
   class OpenAI < Base
     DEFAULTS = {
+      n: 1,
       temperature: 0.0,
       completion_model_name: "gpt-3.5-turbo",
       chat_completion_model_name: "gpt-3.5-turbo",
@@ -26,10 +27,6 @@ module Langchain::LLM
     ].freeze

     LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
-    ROLE_MAPPING = {
-      "ai" => "assistant",
-      "human" => "user"
-    }

     attr_accessor :functions

@@ -45,7 +42,7 @@ module Langchain::LLM
     #
     # @param text [String] The text to generate an embedding for
     # @param params extra parameters passed to OpenAI::Client#embeddings
-    # @return [Array] The embedding
+    # @return [Langchain::LLM::OpenAIResponse] Response object
     #
     def embed(text:, **params)
       parameters = {model: @defaults[:embeddings_model_name], input: text}
@@ -56,7 +53,7 @@ module Langchain::LLM
         client.embeddings(parameters: parameters.merge(params))
       end

-      response.dig("data").first.dig("embedding")
+      Langchain::LLM::OpenAIResponse.new(response)
     end

     #
@@ -64,7 +61,7 @@ module Langchain::LLM
     #
     # @param prompt [String] The prompt to generate a completion for
     # @param params extra parameters passed to OpenAI::Client#complete
-    # @return [String] The completion
+    # @return [Langchain::LLM::Response::OpenaAI] Response object
     #
     def complete(prompt:, **params)
       parameters = compose_parameters @defaults[:completion_model_name], params
@@ -78,7 +75,7 @@ module Langchain::LLM
         client.chat(parameters: parameters)
       end

-      response.dig("choices", 0, "message", "content")
+      Langchain::LLM::OpenAIResponse.new(response)
     end

     #
@@ -117,18 +114,18 @@ module Langchain::LLM
     # },
     # ]
     #
-    # @param prompt [HumanMessage] The prompt to generate a chat completion for
-    # @param messages [Array<AIMessage|HumanMessage>] The messages that have been sent in the conversation
-    # @param context [SystemMessage] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<AIMessage|HumanMessage>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
+    # @param prompt [String] The prompt to generate a chat completion for
+    # @param messages [Array<Hash>] The messages that have been sent in the conversation
+    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
+    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
     # @param options [Hash] extra parameters passed to OpenAI::Client#chat
-    # @yield [AIMessage] Stream responses back one String at a time
-    # @return [AIMessage] The chat completion
+    # @yield [Hash] Stream responses back one token at a time
+    # @return [Langchain::LLM::OpenAIResponse] Response object
     #
-    def chat(prompt: "", messages: [], context: "", examples: [], **options)
+    def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
       raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?

-      parameters = compose_parameters @defaults[:chat_completion_model_name], options
+      parameters = compose_parameters @defaults[:chat_completion_model_name], options, &block
       parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)

       if functions
@@ -137,25 +134,11 @@ module Langchain::LLM
         parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
       end

-      if (streaming = block_given?)
-        parameters[:stream] = proc do |chunk, _bytesize|
-          delta = chunk.dig("choices", 0, "delta")
-          content = delta["content"]
-          additional_kwargs = {function_call: delta["function_call"]}.compact
-          yield Langchain::AIMessage.new(content, additional_kwargs)
-        end
-      end
+      response = with_api_error_handling { client.chat(parameters: parameters) }

-      response = with_api_error_handling do
-        client.chat(parameters: parameters)
-      end
+      return if block

-      unless streaming
-        message = response.dig("choices", 0, "message")
-        content = message["content"]
-        additional_kwargs = {function_call: message["function_call"]}.compact
-        Langchain::AIMessage.new(content.to_s, additional_kwargs)
-      end
+      Langchain::LLM::OpenAIResponse.new(response)
     end

     #
@@ -171,6 +154,7 @@ module Langchain::LLM
       prompt = prompt_template.format(text: text)

       complete(prompt: prompt, temperature: @defaults[:temperature])
+      # Should this return a Langchain::LLM::OpenAIResponse as well?
     end

     private
@@ -191,12 +175,18 @@ module Langchain::LLM
       response.dig("choices", 0, "text")
     end

-    def compose_parameters(model, params)
-      default_params = {model: model, temperature: @defaults[:temperature]}
-
+    def compose_parameters(model, params, &block)
+      default_params = {model: model, temperature: @defaults[:temperature], n: @defaults[:n]}
       default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
+      parameters = default_params.merge(params)

-      default_params.merge(params)
+      if block
+        parameters[:stream] = proc do |chunk, _bytesize|
+          yield chunk.dig("choices", 0)
+        end
+      end
+
+      parameters
     end

     def compose_chat_messages(prompt:, messages: [], context: "", examples: [])
@@ -206,9 +196,9 @@ module Langchain::LLM

       history.concat transform_messages(messages) unless messages.empty?

-      unless context.nil? || context.to_s.empty?
+      unless context.nil? || context.empty?
         history.reject! { |message| message[:role] == "system" }
-        history.prepend({role: "system", content: context.content})
+        history.prepend({role: "system", content: context})
       end

       unless prompt.empty?
@@ -225,14 +215,16 @@ module Langchain::LLM
     def transform_messages(messages)
       messages.map do |message|
         {
-          role: ROLE_MAPPING.fetch(message.type, message.type),
-          content: message.content
+          role: message[:role],
+          content: message[:content]
         }
       end
     end

     def with_api_error_handling
       response = yield
+      return if response.empty?
+
       raise Langchain::LLM::ApiError.new "OpenAI API error: #{response.dig("error", "message")}" if response&.dig("error")

       response
@@ -241,5 +233,10 @@ module Langchain::LLM
     def validate_max_tokens(messages, model)
       LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
     end
+
+    def extract_response(response)
+      results = response.dig("choices").map { |choice| choice.dig("message", "content") }
+      (results.size == 1) ? results.first : results
+    end
   end
 end
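The net effect for callers of the OpenAI adapter: when a block is given, `#chat` now passes the raw `choices[0]` hash for each streamed chunk and returns nothing; otherwise it wraps the payload in a `Langchain::LLM::OpenAIResponse`. A hedged sketch, assuming `OpenAIResponse` implements the `BaseResponse` accessors (its body lives in openai_response.rb, which is not expanded in this excerpt); the prompt and context strings are illustrative:

    # Sketch of the 0.6.18 OpenAI adapter behaviour; prompt/context text is illustrative.
    openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # Non-streaming: the full payload comes back wrapped in an OpenAIResponse.
    response = openai.chat(
      prompt: "What is an Array?",
      context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
    )
    response.chat_completion  # => assistant message text, assuming OpenAIResponse maps it
    response.raw_response     # => the unmodified OpenAI payload

    # Streaming: the block receives the raw `choices[0]` hash per chunk, and #chat returns nil.
    openai.chat(prompt: "What is an Array?") do |choice|
      print choice.dig("delta", "content")
    end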
data/lib/langchain/llm/replicate.rb
@@ -47,44 +47,40 @@ module Langchain::LLM
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @return [Hash] The embedding
+    # @return [Langchain::LLM::ReplicateResponse] Response object
     #
     def embed(text:)
       response = embeddings_model.predict(input: text)

       until response.finished?
         response.refetch
-        sleep(1)
+        sleep(0.1)
       end

-      response.output
+      Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:embeddings_model_name])
     end

     #
     # Generate a completion for a given prompt
     #
     # @param prompt [String] The prompt to generate a completion for
-    # @return [Hash] The completion
+    # @return [Langchain::LLM::ReplicateResponse] Reponse object
     #
     def complete(prompt:, **params)
       response = completion_model.predict(prompt: prompt)

       until response.finished?
         response.refetch
-        sleep(1)
+        sleep(0.1)
       end

-      # Response comes back as an array of strings, e.g.: ["Hi", "how ", "are ", "you?"]
-      # The first array element is missing a space at the end, so we add it manually
-      response.output[0] += " "
-
-      response.output.join
+      Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:completion_model_name])
     end

     # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
     def chat(...)
       response_text = complete(...)
-      Langchain::AIMessage.new(response_text)
+      ::Langchain::Conversation::Response.new(response_text)
     end

     #
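For Replicate, the polling loop now sleeps 0.1s between refetches instead of a full second, and both `#embed` and `#complete` return a `Langchain::LLM::ReplicateResponse`, while `#chat` wraps the completion text in the renamed `Langchain::Conversation::Response`. A hedged sketch; the prompt is illustrative, and the exact accessors depend on replicate_response.rb, which is not expanded in this excerpt:

    # Sketch only; ReplicateResponse's accessors live in replicate_response.rb (not shown here).
    replicate = Langchain::LLM::Replicate.new(api_key: ENV["REPLICATE_API_KEY"])

    response = replicate.complete(prompt: "Tell me a joke")
    response.raw_response  # => the Replicate prediction object, no longer post-processed by #complete
    response.model         # => @defaults[:completion_model_name]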
data/lib/langchain/llm/response/ai21_response.rb (new file)
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class AI21Response < BaseResponse
+    def completions
+      raw_response.dig(:completions)
+    end
+
+    def completion
+      completions.dig(0, :data, :text)
+    end
+  end
+end
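Worth noting: unlike the other wrappers, `AI21Response` digs with symbol keys, so it expects a symbol-keyed payload from the AI21 client. A hedged illustration with a deliberately minimal, made-up payload shape:

    # Illustrative payload only; the real AI21 response carries more fields.
    raw = {completions: [{data: {text: "Hello!"}}]}

    response = Langchain::LLM::AI21Response.new(raw)
    response.completion   # => "Hello!"
    response.completions  # => [{data: {text: "Hello!"}}]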
data/lib/langchain/llm/response/anthropic_response.rb (new file)
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class AnthropicResponse < BaseResponse
+    def model
+      raw_response.dig("model")
+    end
+
+    def completion
+      completions.first
+    end
+
+    def completions
+      [raw_response.dig("completion")]
+    end
+
+    def stop_reason
+      raw_response.dig("stop_reason")
+    end
+
+    def stop
+      raw_response.dig("stop")
+    end
+
+    def log_id
+      raw_response.dig("log_id")
+    end
+  end
+end
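`AnthropicResponse` maps the string-keyed fields of an Anthropic completion payload onto the common accessors, plus a few provider-specific ones (`stop_reason`, `stop`, `log_id`). A hedged illustration with made-up field values:

    # Illustrative payload only; values are made up.
    raw = {
      "model" => "claude-2",
      "completion" => " Hello! How can I help?",
      "stop_reason" => "stop_sequence",
      "stop" => "\n\nHuman:",
      "log_id" => "abc123"
    }

    response = Langchain::LLM::AnthropicResponse.new(raw)
    response.completion   # => " Hello! How can I help?"
    response.stop_reason  # => "stop_sequence"
    response.model        # => "claude-2" (overrides BaseResponse#model, which reads the keyword argument)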
data/lib/langchain/llm/response/base_response.rb (new file)
@@ -0,0 +1,79 @@
+# frozen_string_literal: true
+
+module Langchain
+  module LLM
+    class BaseResponse
+      attr_reader :raw_response, :model
+
+      def initialize(raw_response, model: nil)
+        @raw_response = raw_response
+        @model = model
+      end
+
+      # Returns the completion text
+      #
+      # @return [String]
+      #
+      def completion
+        raise NotImplementedError
+      end
+
+      # Returns the chat completion text
+      #
+      # @return [String]
+      #
+      def chat_completion
+        raise NotImplementedError
+      end
+
+      # Return the first embedding
+      #
+      # @return [Array<Float>]
+      def embedding
+        raise NotImplementedError
+      end
+
+      # Return the completion candidates
+      #
+      # @return [Array]
+      def completions
+        raise NotImplementedError
+      end
+
+      # Return the chat completion candidates
+      #
+      # @return [Array]
+      def chat_completions
+        raise NotImplementedError
+      end
+
+      # Return the embeddings
+      #
+      # @return [Array<Array>]
+      def embeddings
+        raise NotImplementedError
+      end
+
+      # Number of tokens utilized in the prompt
+      #
+      # @return [Integer]
+      def prompt_tokens
+        raise NotImplementedError
+      end
+
+      # Number of tokens utilized to generate the completion
+      #
+      # @return [Integer]
+      def completion_tokens
+        raise NotImplementedError
+      end
+
+      # Total number of tokens utilized
+      #
+      # @return [Integer]
+      def total_tokens
+        raise NotImplementedError
+      end
+    end
+  end
+end
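`BaseResponse` is what makes the per-provider wrappers interchangeable: every adapter's `embed`/`complete`/`chat` now returns an object with the same accessor surface, with `NotImplementedError` raised for fields a provider does not map. A sketch of provider-agnostic consumption; the particular LLM instances and prompt are illustrative:

    # Sketch of uniform consumption via the shared BaseResponse interface.
    llms = [
      Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"]),
      Langchain::LLM::Cohere.new(api_key: ENV["COHERE_API_KEY"])
    ]

    llms.each do |llm|
      response = llm.complete(prompt: "Say hi")
      puts response.completion    # same accessor regardless of provider
      puts response.raw_response  # the untouched payload stays available on every wrapper
    end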
data/lib/langchain/llm/response/cohere_response.rb (new file)
@@ -0,0 +1,21 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class CohereResponse < BaseResponse
+    def embedding
+      embeddings.first
+    end
+
+    def embeddings
+      raw_response.dig("embeddings")
+    end
+
+    def completions
+      raw_response.dig("generations")
+    end
+
+    def completion
+      completions&.dig(0, "text")
+    end
+  end
+end
data/lib/langchain/llm/response/google_palm_response.rb (new file)
@@ -0,0 +1,36 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class GooglePalmResponse < BaseResponse
+    attr_reader :prompt_tokens
+
+    def initialize(raw_response, model: nil, prompt_tokens: nil)
+      @prompt_tokens = prompt_tokens
+      super(raw_response, model: model)
+    end
+
+    def completion
+      completions&.dig(0, "output")
+    end
+
+    def embedding
+      embeddings.first
+    end
+
+    def completions
+      raw_response.dig("candidates")
+    end
+
+    def chat_completion
+      chat_completions&.dig(0, "content")
+    end
+
+    def chat_completions
+      raw_response.dig("candidates")
+    end
+
+    def embeddings
+      [raw_response.dig("embedding", "value")]
+    end
+  end
+end
data/lib/langchain/llm/response/hugging_face_response.rb (new file)
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class HuggingFaceResponse < BaseResponse
+    def embeddings
+      [raw_response]
+    end
+
+    def embedding
+      embeddings.first
+    end
+  end
+end
data/lib/langchain/llm/response/ollama_response.rb (new file)
@@ -0,0 +1,26 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class OllamaResponse < BaseResponse
+    def initialize(raw_response, model: nil, prompt_tokens: nil)
+      @prompt_tokens = prompt_tokens
+      super(raw_response, model: model)
+    end
+
+    def completion
+      raw_response.first
+    end
+
+    def completions
+      raw_response.is_a?(String) ? [raw_response] : []
+    end
+
+    def embedding
+      embeddings.first
+    end
+
+    def embeddings
+      [raw_response&.dig("embedding")]
+    end
+  end
+end