langchainrb 0.7.5 → 0.12.0

Files changed (95)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +78 -0
  3. data/README.md +113 -56
  4. data/lib/langchain/assistants/assistant.rb +213 -0
  5. data/lib/langchain/assistants/message.rb +58 -0
  6. data/lib/langchain/assistants/thread.rb +34 -0
  7. data/lib/langchain/chunker/markdown.rb +37 -0
  8. data/lib/langchain/chunker/recursive_text.rb +0 -2
  9. data/lib/langchain/chunker/semantic.rb +1 -3
  10. data/lib/langchain/chunker/sentence.rb +0 -2
  11. data/lib/langchain/chunker/text.rb +0 -2
  12. data/lib/langchain/contextual_logger.rb +1 -1
  13. data/lib/langchain/data.rb +4 -3
  14. data/lib/langchain/llm/ai21.rb +1 -1
  15. data/lib/langchain/llm/anthropic.rb +86 -11
  16. data/lib/langchain/llm/aws_bedrock.rb +52 -0
  17. data/lib/langchain/llm/azure.rb +10 -97
  18. data/lib/langchain/llm/base.rb +3 -2
  19. data/lib/langchain/llm/cohere.rb +5 -7
  20. data/lib/langchain/llm/google_palm.rb +4 -2
  21. data/lib/langchain/llm/google_vertex_ai.rb +151 -0
  22. data/lib/langchain/llm/hugging_face.rb +1 -1
  23. data/lib/langchain/llm/llama_cpp.rb +18 -16
  24. data/lib/langchain/llm/mistral_ai.rb +68 -0
  25. data/lib/langchain/llm/ollama.rb +209 -27
  26. data/lib/langchain/llm/openai.rb +138 -170
  27. data/lib/langchain/llm/prompts/ollama/summarize_template.yaml +9 -0
  28. data/lib/langchain/llm/replicate.rb +1 -7
  29. data/lib/langchain/llm/response/anthropic_response.rb +20 -0
  30. data/lib/langchain/llm/response/base_response.rb +7 -0
  31. data/lib/langchain/llm/response/google_palm_response.rb +4 -0
  32. data/lib/langchain/llm/response/google_vertex_ai_response.rb +33 -0
  33. data/lib/langchain/llm/response/llama_cpp_response.rb +13 -0
  34. data/lib/langchain/llm/response/mistral_ai_response.rb +39 -0
  35. data/lib/langchain/llm/response/ollama_response.rb +27 -1
  36. data/lib/langchain/llm/response/openai_response.rb +8 -0
  37. data/lib/langchain/loader.rb +3 -2
  38. data/lib/langchain/output_parsers/base.rb +0 -4
  39. data/lib/langchain/output_parsers/output_fixing_parser.rb +7 -14
  40. data/lib/langchain/output_parsers/structured_output_parser.rb +0 -10
  41. data/lib/langchain/processors/csv.rb +37 -3
  42. data/lib/langchain/processors/eml.rb +64 -0
  43. data/lib/langchain/processors/markdown.rb +17 -0
  44. data/lib/langchain/processors/pptx.rb +29 -0
  45. data/lib/langchain/prompt/loading.rb +1 -1
  46. data/lib/langchain/tool/base.rb +21 -53
  47. data/lib/langchain/tool/calculator/calculator.json +19 -0
  48. data/lib/langchain/tool/{calculator.rb → calculator/calculator.rb} +8 -16
  49. data/lib/langchain/tool/database/database.json +46 -0
  50. data/lib/langchain/tool/database/database.rb +99 -0
  51. data/lib/langchain/tool/file_system/file_system.json +57 -0
  52. data/lib/langchain/tool/file_system/file_system.rb +32 -0
  53. data/lib/langchain/tool/google_search/google_search.json +19 -0
  54. data/lib/langchain/tool/{google_search.rb → google_search/google_search.rb} +5 -15
  55. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +19 -0
  56. data/lib/langchain/tool/{ruby_code_interpreter.rb → ruby_code_interpreter/ruby_code_interpreter.rb} +8 -4
  57. data/lib/langchain/tool/vectorsearch/vectorsearch.json +24 -0
  58. data/lib/langchain/tool/vectorsearch/vectorsearch.rb +36 -0
  59. data/lib/langchain/tool/weather/weather.json +19 -0
  60. data/lib/langchain/tool/{weather.rb → weather/weather.rb} +3 -15
  61. data/lib/langchain/tool/wikipedia/wikipedia.json +19 -0
  62. data/lib/langchain/tool/{wikipedia.rb → wikipedia/wikipedia.rb} +9 -9
  63. data/lib/langchain/utils/token_length/ai21_validator.rb +6 -2
  64. data/lib/langchain/utils/token_length/base_validator.rb +1 -1
  65. data/lib/langchain/utils/token_length/cohere_validator.rb +6 -2
  66. data/lib/langchain/utils/token_length/google_palm_validator.rb +5 -1
  67. data/lib/langchain/utils/token_length/openai_validator.rb +55 -1
  68. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +1 -1
  69. data/lib/langchain/vectorsearch/base.rb +11 -4
  70. data/lib/langchain/vectorsearch/chroma.rb +10 -1
  71. data/lib/langchain/vectorsearch/elasticsearch.rb +53 -4
  72. data/lib/langchain/vectorsearch/epsilla.rb +149 -0
  73. data/lib/langchain/vectorsearch/hnswlib.rb +5 -1
  74. data/lib/langchain/vectorsearch/milvus.rb +4 -2
  75. data/lib/langchain/vectorsearch/pgvector.rb +14 -4
  76. data/lib/langchain/vectorsearch/pinecone.rb +8 -5
  77. data/lib/langchain/vectorsearch/qdrant.rb +16 -4
  78. data/lib/langchain/vectorsearch/weaviate.rb +20 -2
  79. data/lib/langchain/version.rb +1 -1
  80. data/lib/langchain.rb +20 -5
  81. metadata +182 -45
  82. data/lib/langchain/agent/agents.md +0 -54
  83. data/lib/langchain/agent/base.rb +0 -20
  84. data/lib/langchain/agent/react_agent/react_agent_prompt.yaml +0 -26
  85. data/lib/langchain/agent/react_agent.rb +0 -131
  86. data/lib/langchain/agent/sql_query_agent/sql_query_agent_answer_prompt.yaml +0 -11
  87. data/lib/langchain/agent/sql_query_agent/sql_query_agent_sql_prompt.yaml +0 -21
  88. data/lib/langchain/agent/sql_query_agent.rb +0 -82
  89. data/lib/langchain/conversation/context.rb +0 -8
  90. data/lib/langchain/conversation/memory.rb +0 -86
  91. data/lib/langchain/conversation/message.rb +0 -48
  92. data/lib/langchain/conversation/prompt.rb +0 -8
  93. data/lib/langchain/conversation/response.rb +0 -8
  94. data/lib/langchain/conversation.rb +0 -93
  95. data/lib/langchain/tool/database.rb +0 -90

data/lib/langchain/llm/openai.rb

@@ -4,223 +4,196 @@ module Langchain::LLM
   # LLM interface for OpenAI APIs: https://platform.openai.com/overview
   #
   # Gem requirements:
-  # gem "ruby-openai", "~> 5.2.0"
+  # gem "ruby-openai", "~> 6.3.0"
   #
   # Usage:
-  # openai = Langchain::LLM::OpenAI.new(api_key:, llm_options: {})
-  #
+  # openai = Langchain::LLM::OpenAI.new(
+  #   api_key: ENV["OPENAI_API_KEY"],
+  #   llm_options: {}, # Available options: https://github.com/alexrudall/ruby-openai/blob/main/lib/openai/client.rb#L5-L13
+  #   default_options: {}
+  # )
   class OpenAI < Base
     DEFAULTS = {
       n: 1,
       temperature: 0.0,
-      completion_model_name: "gpt-3.5-turbo",
       chat_completion_model_name: "gpt-3.5-turbo",
-      embeddings_model_name: "text-embedding-ada-002",
-      dimension: 1536
+      embeddings_model_name: "text-embedding-3-small"
     }.freeze
 
-    LEGACY_COMPLETION_MODELS = %w[
-      ada
-      babbage
-      curie
-      davinci
-    ].freeze
+    EMBEDDING_SIZES = {
+      "text-embedding-ada-002" => 1536,
+      "text-embedding-3-large" => 3072,
+      "text-embedding-3-small" => 1536
+    }.freeze
 
     LENGTH_VALIDATOR = Langchain::Utils::TokenLength::OpenAIValidator
 
-    attr_accessor :functions
-    attr_accessor :response_chunks
+    attr_reader :defaults
 
+    # Initialize an OpenAI LLM instance
+    #
+    # @param api_key [String] The API key to use
+    # @param client_options [Hash] Options to pass to the OpenAI::Client constructor
     def initialize(api_key:, llm_options: {}, default_options: {})
       depends_on "ruby-openai", req: "openai"
 
       @client = ::OpenAI::Client.new(access_token: api_key, **llm_options)
+
       @defaults = DEFAULTS.merge(default_options)
     end
 
-    #
     # Generate an embedding for a given text
     #
     # @param text [String] The text to generate an embedding for
-    # @param params extra parameters passed to OpenAI::Client#embeddings
+    # @param model [String] ID of the model to use
+    # @param encoding_format [String] The format to return the embeddings in. Can be either float or base64.
+    # @param user [String] A unique identifier representing your end-user
    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def embed(text:, **params)
-      parameters = {model: @defaults[:embeddings_model_name], input: text}
+    def embed(
+      text:,
+      model: defaults[:embeddings_model_name],
+      encoding_format: nil,
+      user: nil,
+      dimensions: @defaults[:dimensions]
+    )
+      raise ArgumentError.new("text argument is required") if text.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("encoding_format must be either float or base64") if encoding_format && %w[float base64].include?(encoding_format)
+
+      parameters = {
+        input: text,
+        model: model
+      }
+      parameters[:encoding_format] = encoding_format if encoding_format
+      parameters[:user] = user if user
+
+      if dimensions
+        parameters[:dimensions] = dimensions
+      elsif EMBEDDING_SIZES.key?(model)
+        parameters[:dimensions] = EMBEDDING_SIZES[model]
+      end
 
       validate_max_tokens(text, parameters[:model])
 
       response = with_api_error_handling do
-        client.embeddings(parameters: parameters.merge(params))
+        client.embeddings(parameters: parameters)
       end
 
       Langchain::LLM::OpenAIResponse.new(response)
     end
 
-    #
+    # rubocop:disable Style/ArgumentsForwarding
     # Generate a completion for a given prompt
     #
     # @param prompt [String] The prompt to generate a completion for
-    # @param params extra parameters passed to OpenAI::Client#complete
-    # @return [Langchain::LLM::Response::OpenaAI] Response object
-    #
+    # @param params [Hash] The parameters to pass to the `chat()` method
+    # @return [Langchain::LLM::OpenAIResponse] Response object
     def complete(prompt:, **params)
-      parameters = compose_parameters @defaults[:completion_model_name], params
+      warn "DEPRECATED: `Langchain::LLM::OpenAI#complete` is deprecated, and will be removed in the next major version. Use `Langchain::LLM::OpenAI#chat` instead."
 
-      return legacy_complete(prompt, parameters) if is_legacy_model?(parameters[:model])
-
-      parameters[:messages] = compose_chat_messages(prompt: prompt)
-      parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], parameters[:max_tokens])
-
-      response = with_api_error_handling do
-        client.chat(parameters: parameters)
+      if params[:stop_sequences]
+        params[:stop] = params.delete(:stop_sequences)
       end
-
-      Langchain::LLM::OpenAIResponse.new(response)
+      # Should we still accept the `messages: []` parameter here?
+      messages = [{role: "user", content: prompt}]
+      chat(messages: messages, **params)
     end
+    # rubocop:enable Style/ArgumentsForwarding
+
+    # Generate a chat completion for given messages.
+    #
+    # @param messages [Array<Hash>] List of messages comprising the conversation so far
+    # @param model [String] ID of the model to use
+    def chat(
+      messages: [],
+      model: defaults[:chat_completion_model_name],
+      frequency_penalty: nil,
+      logit_bias: nil,
+      logprobs: nil,
+      top_logprobs: nil,
+      max_tokens: nil,
+      n: defaults[:n],
+      presence_penalty: nil,
+      response_format: nil,
+      seed: nil,
+      stop: nil,
+      stream: nil,
+      temperature: defaults[:temperature],
+      top_p: nil,
+      tools: [],
+      tool_choice: nil,
+      user: nil,
+      &block
+    )
+      raise ArgumentError.new("messages argument is required") if messages.empty?
+      raise ArgumentError.new("model argument is required") if model.empty?
+      raise ArgumentError.new("'tool_choice' is only allowed when 'tools' are specified.") if tool_choice && tools.empty?
+
+      parameters = {
+        messages: messages,
+        model: model
+      }
+      parameters[:frequency_penalty] = frequency_penalty if frequency_penalty
+      parameters[:logit_bias] = logit_bias if logit_bias
+      parameters[:logprobs] = logprobs if logprobs
+      parameters[:top_logprobs] = top_logprobs if top_logprobs
+      # TODO: Fix max_tokens validation to account for tools/functions
+      parameters[:max_tokens] = max_tokens if max_tokens # || validate_max_tokens(parameters[:messages], parameters[:model])
+      parameters[:n] = n if n
+      parameters[:presence_penalty] = presence_penalty if presence_penalty
+      parameters[:response_format] = response_format if response_format
+      parameters[:seed] = seed if seed
+      parameters[:stop] = stop if stop
+      parameters[:stream] = stream if stream
+      parameters[:temperature] = temperature if temperature
+      parameters[:top_p] = top_p if top_p
+      parameters[:tools] = tools if tools.any?
+      parameters[:tool_choice] = tool_choice if tool_choice
+      parameters[:user] = user if user
+
+      # TODO: Clean this part up
+      if block
+        @response_chunks = []
+        parameters[:stream] = proc do |chunk, _bytesize|
+          chunk_content = chunk.dig("choices", 0)
+          @response_chunks << chunk
+          yield chunk_content
+        end
+      end
 
-    #
-    # Generate a chat completion for a given prompt or messages.
-    #
-    # == Examples
-    #
-    # # simplest case, just give a prompt
-    # openai.chat prompt: "When was Ruby first released?"
-    #
-    # # prompt plus some context about how to respond
-    # openai.chat context: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #
-    # # full control over messages that get sent, equivilent to the above
-    # openai.chat messages: [
-    #   {
-    #     role: "system",
-    #     content: "You are RubyGPT, a helpful chat bot for helping people learn Ruby", prompt: "Does Ruby have a REPL like IPython?"
-    #   },
-    #   {
-    #     role: "user",
-    #     content: "When was Ruby first released?"
-    #   }
-    # ]
-    #
-    # # few-short prompting with examples
-    # openai.chat prompt: "When was factory_bot released?",
-    #   examples: [
-    #     {
-    #       role: "user",
-    #       content: "When was Ruby on Rails released?"
-    #     }
-    #     {
-    #       role: "assistant",
-    #       content: "2004"
-    #     },
-    #   ]
-    #
-    # @param prompt [String] The prompt to generate a chat completion for
-    # @param messages [Array<Hash>] The messages that have been sent in the conversation
-    # @param context [String] An initial context to provide as a system message, ie "You are RubyGPT, a helpful chat bot for helping people learn Ruby"
-    # @param examples [Array<Hash>] Examples of messages to provide to the model. Useful for Few-Shot Prompting
-    # @param options [Hash] extra parameters passed to OpenAI::Client#chat
-    # @yield [Hash] Stream responses back one token at a time
-    # @return [Langchain::LLM::OpenAIResponse] Response object
-    #
-    def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
-      raise ArgumentError.new(":prompt or :messages argument is expected") if prompt.empty? && messages.empty?
-
-      parameters = compose_parameters @defaults[:chat_completion_model_name], options, &block
-      parameters[:messages] = compose_chat_messages(prompt: prompt, messages: messages, context: context, examples: examples)
-
-      if functions
-        parameters[:functions] = functions
-      else
-        parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], parameters[:max_tokens])
+      response = with_api_error_handling do
+        client.chat(parameters: parameters)
       end
 
-      response = with_api_error_handling { client.chat(parameters: parameters) }
       response = response_from_chunks if block
+      reset_response_chunks
+
       Langchain::LLM::OpenAIResponse.new(response)
     end
 
-    #
     # Generate a summary for a given text
     #
     # @param text [String] The text to generate a summary for
     # @return [String] The summary
-    #
     def summarize(text:)
       prompt_template = Langchain::Prompt.load_from_path(
         file_path: Langchain.root.join("langchain/llm/prompts/summarize_template.yaml")
       )
       prompt = prompt_template.format(text: text)
 
-      complete(prompt: prompt, temperature: @defaults[:temperature])
-      # Should this return a Langchain::LLM::OpenAIResponse as well?
-    end
-
-    private
-
-    def is_legacy_model?(model)
-      LEGACY_COMPLETION_MODELS.any? { |legacy_model| model.include?(legacy_model) }
+      complete(prompt: prompt)
     end
 
-    def legacy_complete(prompt, parameters)
-      Langchain.logger.warn "DEPRECATION WARNING: The model #{parameters[:model]} is deprecated. Please use gpt-3.5-turbo instead. Details: https://platform.openai.com/docs/deprecations/2023-07-06-gpt-and-embeddings"
-
-      parameters[:prompt] = prompt
-      parameters[:max_tokens] = validate_max_tokens(prompt, parameters[:model])
-
-      response = with_api_error_handling do
-        client.completions(parameters: parameters)
-      end
-      response.dig("choices", 0, "text")
-    end
-
-    def compose_parameters(model, params, &block)
-      default_params = {model: model, temperature: @defaults[:temperature], n: @defaults[:n]}
-      default_params[:stop] = params.delete(:stop_sequences) if params[:stop_sequences]
-      parameters = default_params.merge(params)
-
-      if block
-        @response_chunks = []
-        parameters[:stream] = proc do |chunk, _bytesize|
-          chunk_content = chunk.dig("choices", 0)
-          @response_chunks << chunk
-          yield chunk_content
-        end
-      end
-
-      parameters
+    def default_dimensions
+      @defaults[:dimensions] || EMBEDDING_SIZES.fetch(defaults[:embeddings_model_name])
     end
 
-    def compose_chat_messages(prompt:, messages: [], context: "", examples: [])
-      history = []
-
-      history.concat transform_messages(examples) unless examples.empty?
-
-      history.concat transform_messages(messages) unless messages.empty?
-
-      unless context.nil? || context.empty?
-        history.reject! { |message| message[:role] == "system" }
-        history.prepend({role: "system", content: context})
-      end
-
-      unless prompt.empty?
-        if history.last && history.last[:role] == "user"
-          history.last[:content] += "\n#{prompt}"
-        else
-          history.append({role: "user", content: prompt})
-        end
-      end
+    private
 
-      history
-    end
+    attr_reader :response_chunks
 
-    def transform_messages(messages)
-      messages.map do |message|
-        {
-          role: message[:role],
-          content: message[:content]
-        }
-      end
+    def reset_response_chunks
+      @response_chunks = []
     end
 
     def with_api_error_handling
@@ -233,27 +206,22 @@ module Langchain::LLM
     end
 
     def validate_max_tokens(messages, model, max_tokens = nil)
-      LENGTH_VALIDATOR.validate_max_tokens!(messages, model, max_tokens: max_tokens)
-    end
-
-    def extract_response(response)
-      results = response.dig("choices").map { |choice| choice.dig("message", "content") }
-      (results.size == 1) ? results.first : results
+      LENGTH_VALIDATOR.validate_max_tokens!(messages, model, max_tokens: max_tokens, llm: self)
     end
 
     def response_from_chunks
-      @response_chunks.first&.slice("id", "object", "created", "model")&.merge(
+      grouped_chunks = @response_chunks.group_by { |chunk| chunk.dig("choices", 0, "index") }
+      final_choices = grouped_chunks.map do |index, chunks|
         {
-          "choices" => [
-            {
-              "message" => {
-                "role" => "assistant",
-                "content" => @response_chunks.map { |chunk| chunk.dig("choices", 0, "delta", "content") }.join
-              }
-            }
-          ]
+          "index" => index,
+          "message" => {
+            "role" => "assistant",
+            "content" => chunks.map { |chunk| chunk.dig("choices", 0, "delta", "content") }.join
+          },
+          "finish_reason" => chunks.last.dig("choices", 0, "finish_reason")
        }
-      )
+      end
+      @response_chunks.first&.slice("id", "object", "created", "model")&.merge({"choices" => final_choices})
     end
   end
 end
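
For orientation, a minimal usage sketch of the reworked OpenAI interface (not part of the diff). It assumes the gem is installed and OPENAI_API_KEY is set; the messages and return-value comments are illustrative only:

    require "langchain"

    openai = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])

    # embed() now takes explicit keyword arguments instead of a catch-all **params hash
    embedding = openai.embed(text: "Ruby is a programmer's best friend")
    embedding.embedding # => Array of floats

    # chat() now requires a messages: array; the old prompt:/context:/examples: keywords are gone
    response = openai.chat(
      messages: [
        {role: "system", content: "You are a helpful assistant"},
        {role: "user", content: "When was Ruby first released?"}
      ],
      temperature: 0.0
    )
    response.chat_completion # => assistant reply text

    # Streaming: pass a block; each yielded chunk is the first element of "choices"
    openai.chat(messages: [{role: "user", content: "Count to three"}]) do |chunk|
      print chunk.dig("delta", "content")
    end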

data/lib/langchain/llm/prompts/ollama/summarize_template.yaml

@@ -0,0 +1,9 @@
+_type: prompt
+input_variables:
+  - text
+template: |
+  Write a concise summary of the following TEXT. Do not include the word summary, just provide the summary.
+
+  TEXT: {text}
+
+  CONCISE SUMMARY:
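
The new Ollama summarize template uses the same serialized prompt format the gem already ships for OpenAI. A small sketch of loading it (the path mirrors where the file lives inside the gem):

    require "langchain"

    template = Langchain::Prompt.load_from_path(
      file_path: Langchain.root.join("langchain/llm/prompts/ollama/summarize_template.yaml")
    )
    template.input_variables # => ["text"]
    puts template.format(text: "Ollama runs large language models locally...")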

data/lib/langchain/llm/replicate.rb

@@ -24,7 +24,7 @@ module Langchain::LLM
       # TODO: Design the interface to pass and use different models
       completion_model_name: "replicate/vicuna-13b",
       embeddings_model_name: "creatorrr/all-mpnet-base-v2",
-      dimension: 384
+      dimensions: 384
     }.freeze
 
     #
@@ -77,12 +77,6 @@ module Langchain::LLM
       Langchain::LLM::ReplicateResponse.new(response, model: @defaults[:completion_model_name])
     end
 
-    # Cohere does not have a dedicated chat endpoint, so instead we call `complete()`
-    def chat(...)
-      response_text = complete(...)
-      ::Langchain::Conversation::Response.new(response_text)
-    end
-
     #
     # Generate a summary for a given text
     #

data/lib/langchain/llm/response/anthropic_response.rb

@@ -10,6 +10,10 @@ module Langchain::LLM
       completions.first
     end
 
+    def chat_completion
+      raw_response.dig("content", 0, "text")
+    end
+
     def completions
       [raw_response.dig("completion")]
     end
@@ -25,5 +29,21 @@ module Langchain::LLM
     def log_id
       raw_response.dig("log_id")
     end
+
+    def prompt_tokens
+      raw_response.dig("usage", "input_tokens").to_i
+    end
+
+    def completion_tokens
+      raw_response.dig("usage", "output_tokens").to_i
+    end
+
+    def total_tokens
+      prompt_tokens + completion_tokens
+    end
+
+    def role
+      raw_response.dig("role")
+    end
   end
 end
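
The added accessors read the Anthropic Messages API payload directly. A quick sketch with an illustrative, hand-written response hash:

    require "langchain"

    raw = {
      "role" => "assistant",
      "content" => [{"type" => "text", "text" => "Hi there!"}],
      "usage" => {"input_tokens" => 10, "output_tokens" => 3}
    }

    response = Langchain::LLM::AnthropicResponse.new(raw)
    response.chat_completion   # => "Hi there!"
    response.role              # => "assistant"
    response.prompt_tokens     # => 10
    response.completion_tokens # => 3
    response.total_tokens      # => 13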

data/lib/langchain/llm/response/base_response.rb

@@ -13,6 +13,13 @@ module Langchain
         @model = model
       end
 
+      # Returns the timestamp when the response was created
+      #
+      # @return [Time]
+      def created_at
+        raise NotImplementedError
+      end
+
       # Returns the completion text
       #
       # @return [String]
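
created_at joins the shared response contract here; concrete wrappers (Ollama and Mistral AI below) override it. A hypothetical subclass, only to illustrate the contract:

    require "langchain"

    # Not part of the gem; assumes a provider that returns a Unix timestamp under "created"
    class MyProviderResponse < Langchain::LLM::BaseResponse
      def created_at
        Time.at(raw_response["created"]) if raw_response["created"]
      end
    end

    MyProviderResponse.new({"created" => 1_700_000_000}).created_at # => a Time instance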

data/lib/langchain/llm/response/google_palm_response.rb

@@ -32,5 +32,9 @@ module Langchain::LLM
     def embeddings
       [raw_response.dig("embedding", "value")]
     end
+
+    def role
+      "assistant"
+    end
   end
 end

data/lib/langchain/llm/response/google_vertex_ai_response.rb

@@ -0,0 +1,33 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class GoogleVertexAiResponse < BaseResponse
+    attr_reader :prompt_tokens
+
+    def initialize(raw_response, model: nil)
+      @prompt_tokens = prompt_tokens
+      super(raw_response, model: model)
+    end
+
+    def completion
+      # completions&.dig(0, "output")
+      raw_response.predictions[0]["content"]
+    end
+
+    def embedding
+      embeddings.first
+    end
+
+    def completions
+      raw_response.predictions.map { |p| p["content"] }
+    end
+
+    def total_tokens
+      raw_response.dig(:predictions, 0, :embeddings, :statistics, :token_count)
+    end
+
+    def embeddings
+      [raw_response.dig(:predictions, 0, :embeddings, :values)]
+    end
+  end
+end
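
The Vertex AI wrapper reads predictions off the Google client's response object rather than a plain hash. A rough sketch using OpenStruct as a stand-in for that object (the real client returns a richer type, so treat this as illustrative only):

    require "langchain"
    require "ostruct"

    raw = OpenStruct.new(
      predictions: [
        {embeddings: {values: [0.1, -0.2, 0.3], statistics: {token_count: 4}}}
      ]
    )

    response = Langchain::LLM::GoogleVertexAiResponse.new(raw)
    response.embedding    # => [0.1, -0.2, 0.3]
    response.total_tokens # => 4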

data/lib/langchain/llm/response/llama_cpp_response.rb

@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class LlamaCppResponse < BaseResponse
+    def embedding
+      embeddings
+    end
+
+    def embeddings
+      raw_response.embeddings
+    end
+  end
+end

data/lib/langchain/llm/response/mistral_ai_response.rb

@@ -0,0 +1,39 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class MistralAIResponse < BaseResponse
+    def model
+      raw_response["model"]
+    end
+
+    def chat_completion
+      raw_response.dig("choices", 0, "message", "content")
+    end
+
+    def role
+      raw_response.dig("choices", 0, "message", "role")
+    end
+
+    def embedding
+      raw_response.dig("data", 0, "embedding")
+    end
+
+    def prompt_tokens
+      raw_response.dig("usage", "prompt_tokens")
+    end
+
+    def total_tokens
+      raw_response.dig("usage", "total_tokens")
+    end
+
+    def completion_tokens
+      raw_response.dig("usage", "completion_tokens")
+    end
+
+    def created_at
+      if raw_response.dig("created_at")
+        Time.at(raw_response.dig("created_at"))
+      end
+    end
+  end
+end
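
Mistral AI returns an OpenAI-compatible chat payload, which is what these accessors dig into. A sketch with an illustrative hash:

    require "langchain"

    raw = {
      "model" => "mistral-tiny",
      "choices" => [{"message" => {"role" => "assistant", "content" => "Bonjour!"}}],
      "usage" => {"prompt_tokens" => 9, "completion_tokens" => 2, "total_tokens" => 11}
    }

    response = Langchain::LLM::MistralAIResponse.new(raw, model: "mistral-tiny")
    response.chat_completion # => "Bonjour!"
    response.role            # => "assistant"
    response.total_tokens    # => 11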

data/lib/langchain/llm/response/ollama_response.rb

@@ -7,8 +7,18 @@ module Langchain::LLM
       super(raw_response, model: model)
     end
 
+    def created_at
+      if raw_response.dig("created_at")
+        Time.parse(raw_response.dig("created_at"))
+      end
+    end
+
+    def chat_completion
+      raw_response.dig("message", "content")
+    end
+
     def completion
-      raw_response.first
+      completions.first
     end
 
     def completions
@@ -22,5 +32,21 @@ module Langchain::LLM
     def embeddings
       [raw_response&.dig("embedding")]
     end
+
+    def role
+      "assistant"
+    end
+
+    def prompt_tokens
+      raw_response.dig("prompt_eval_count")
+    end
+
+    def completion_tokens
+      raw_response.dig("eval_count")
+    end
+
+    def total_tokens
+      prompt_tokens + completion_tokens
+    end
   end
 end
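
The Ollama accessors map onto the fields returned by Ollama's /api/chat and /api/generate endpoints. A sketch with an illustrative payload:

    require "langchain"
    require "time"

    raw = {
      "model" => "llama2",
      "created_at" => "2024-03-01T12:00:00Z",
      "message" => {"role" => "assistant", "content" => "42"},
      "prompt_eval_count" => 12,
      "eval_count" => 1
    }

    response = Langchain::LLM::OllamaResponse.new(raw, model: "llama2")
    response.chat_completion # => "42"
    response.created_at      # => a Time, parsed from "created_at"
    response.total_tokens    # => 13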

data/lib/langchain/llm/response/openai_response.rb

@@ -16,10 +16,18 @@ module Langchain::LLM
       completions&.dig(0, "message", "content")
     end
 
+    def role
+      completions&.dig(0, "message", "role")
+    end
+
     def chat_completion
       completion
     end
 
+    def tool_calls
+      chat_completions&.dig(0, "message", "tool_calls")
+    end
+
     def embedding
       embeddings&.first
     end
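
role and tool_calls expose the assistant message's metadata from the "choices" array. A sketch with an illustrative chat completion hash:

    require "langchain"

    raw = {
      "choices" => [
        {"message" => {"role" => "assistant", "content" => "Hello!"}}
      ]
    }

    response = Langchain::LLM::OpenAIResponse.new(raw)
    response.chat_completion # => "Hello!"
    response.role            # => "assistant"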

data/lib/langchain/loader.rb

@@ -37,9 +37,10 @@ module Langchain
     # @param path [String | Pathname] path to file or URL
     # @param options [Hash] options passed to the processor class used to process the data
     # @return [Langchain::Loader] loader instance
-    def initialize(path, options = {})
+    def initialize(path, options = {}, chunker: Langchain::Chunker::Text)
       @options = options
       @path = path
+      @chunker = chunker
     end
 
     # Is the path a URL?
@@ -112,7 +113,7 @@ module Langchain
         processor_klass.new(@options).parse(@raw_data)
       end
 
-      Langchain::Data.new(result)
+      Langchain::Data.new(result, source: @options[:source], chunker: @chunker)
     end
 
     def processor_klass
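
Loader now accepts the chunker as a keyword and threads it (together with a source option) into the resulting Langchain::Data. A rough sketch with a hypothetical file path, assuming Data#chunks delegates to the chosen chunker:

    require "langchain"

    data = Langchain::Loader.new(
      "docs/handbook.md",            # hypothetical path
      {source: "docs/handbook.md"},
      chunker: Langchain::Chunker::RecursiveText
    ).load

    data.chunks # => chunks produced by the chunker passed above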

data/lib/langchain/output_parsers/base.rb

@@ -5,18 +5,15 @@ module Langchain::OutputParsers
   #
   # @abstract
   class Base
-    #
     # Parse the output of an LLM call.
     #
     # @param text - LLM output to parse.
     #
     # @return [Object] Parsed output.
-    #
     def parse(text:)
       raise NotImplementedError
     end
 
-    #
     # Return a string describing the format of the output.
     #
     # @return [String] Format instructions.
@@ -27,7 +24,6 @@ module Langchain::OutputParsers
     # "foo": "bar"
     # }
     # ```
-    #
     def get_format_instructions
       raise NotImplementedError
     end
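
The parser base class stays an abstract interface: subclasses implement parse and get_format_instructions. A hypothetical parser, only to show the contract:

    require "langchain"
    require "json"

    class JsonListParser < Langchain::OutputParsers::Base
      def parse(text:)
        JSON.parse(text)
      end

      def get_format_instructions
        "Respond with a JSON array of strings and nothing else."
      end
    end

    JsonListParser.new.parse(text: '["a", "b"]') # => ["a", "b"]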