langchainrb 0.13.5 → 0.14.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: d7eac7a6ba7767f6a3f84ee808fa4810eaa1843776695ab0225ddd6b77cf7a73
- data.tar.gz: e9f7c0170fc2a8dbf443f1bac24874878ee0fbba7e0495bf65a8df969d3d86e6
+ metadata.gz: 68900cd116cf0fb1b77376a4906e5551f0d578ee2bb47c7ec86d32bf44f84e33
+ data.tar.gz: f68782c3cdc856799778618d78b6411a85b0c69adf6a4d33489b8025fdca3dce
  SHA512:
- metadata.gz: e4d14ac64e54e5c7245a9586dfb4899154793ea466f9564a510eb3dfe17a3a7229cf61e408445b38fec37500065b5e1ee725afa634284bea5538abac0766237f
- data.tar.gz: e8fe3e1639a3f2ed087436610dd1653e775703c1c6cc83f7f52eb7d3fb46db554e7be790bc6bc2ddf18ec4e3c26dddbe1ec72e8f25603db1192e5a111d0f9543
+ metadata.gz: 158410fd769caaf9074eddc1143ddee9256ac5a466a510c32b74d337eba62fab80b676661cbf1673604d236014a5cb4defdd4743e71abb713a659ddea0fe5e8c
+ data.tar.gz: 2e956356a443ff37ad711f6c42f8c4940925bcee4be075b403c78c3f702b487c12790dca9ba7d68a01acaf1c245b2910650b3f938e80cedd1fc2d5af14f7ffa8
data/CHANGELOG.md CHANGED
@@ -1,5 +1,11 @@
  ## [Unreleased]

+ ## [0.14.0] - 2024-07-12
+ - Removed TokenLength validators
+ - Assistant works with a Mistral LLM now
+ - Assistant keeps track of tokens used
+ - Misc fixes and improvements
+
  ## [0.13.5] - 2024-07-01
  - Add Milvus#remove_texts() method
  - Langchain::Assistant has a `state` now
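Per the 0.14.0 entries above, the Assistant now tracks token usage across LLM calls. A minimal usage sketch against the API shown in the README and Assistant diffs below (the instructions string is illustrative):

```ruby
require "langchain"

llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
assistant = Langchain::Assistant.new(llm: llm, instructions: "You are a helpful assistant")

assistant.add_message_and_run content: "Hello!"

# New attr_readers added in this release (see the Assistant diff below)
puts assistant.total_prompt_tokens
puts assistant.total_completion_tokens
puts assistant.total_tokens
```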
data/README.md CHANGED
@@ -428,25 +428,10 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
  ```ruby
  llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
  ```
- 2. Instantiate a Thread. Threads keep track of the messages in the Assistant conversation.
- ```ruby
- thread = Langchain::Thread.new
- ```
- You can pass old message from previously using the Assistant:
- ```ruby
- thread.messages = messages
- ```
- Messages contain the conversation history and the whole message history is sent to the LLM every time. A Message belongs to 1 of the 4 roles:
- * `Message(role: "system")` message usually contains the instructions.
- * `Message(role: "user")` messages come from the user.
- * `Message(role: "assistant")` messages are produced by the LLM.
- * `Message(role: "tool")` messages are sent in response to tool calls with tool outputs.
-
- 3. Instantiate an Assistant
+ 2. Instantiate an Assistant
  ```ruby
  assistant = Langchain::Assistant.new(
  llm: llm,
- thread: thread,
  instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
  tools: [
  Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
@@ -482,7 +467,7 @@ assistant.add_message_and_run content: "What about Sacramento, CA?", auto_tool_e
  ### Accessing Thread messages
  You can access the messages in a Thread by calling `assistant.thread.messages`.
  ```ruby
- assistant.thread.messages
+ assistant.messages
  ```

  The Assistant checks the context window limits before every request to the LLM and removes the oldest thread messages one by one if the context window is exceeded.
@@ -16,13 +16,15 @@ module Langchain
  def_delegators :thread, :messages, :messages=

  attr_reader :llm, :thread, :instructions, :state
+ attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens
  attr_accessor :tools

  SUPPORTED_LLMS = [
  Langchain::LLM::Anthropic,
- Langchain::LLM::OpenAI,
  Langchain::LLM::GoogleGemini,
- Langchain::LLM::GoogleVertexAI
+ Langchain::LLM::GoogleVertexAI,
+ Langchain::LLM::Ollama,
+ Langchain::LLM::OpenAI
  ]

  # Create a new assistant
@@ -40,6 +42,9 @@ module Langchain
  unless SUPPORTED_LLMS.include?(llm.class)
  raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
  end
+ if llm.is_a?(Langchain::LLM::Ollama)
+ raise ArgumentError, "Currently only `mistral:7b-instruct-v0.3-fp16` model is supported for Ollama LLM" unless llm.defaults[:completion_model_name] == "mistral:7b-instruct-v0.3-fp16"
+ end
  raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }

  @llm = llm
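The guard added above pins Ollama to the `mistral:7b-instruct-v0.3-fp16` model. A minimal sketch of satisfying it, assuming a local Ollama server and that `Langchain::LLM::Ollama.new` accepts `url:` and `default_options:` (both are assumptions about the constructor, not confirmed by this diff):

```ruby
# Pull the model first: `ollama pull mistral:7b-instruct-v0.3-fp16`
llm = Langchain::LLM::Ollama.new(
  url: "http://localhost:11434", # assumed default Ollama endpoint
  default_options: {completion_model_name: "mistral:7b-instruct-v0.3-fp16"}
)

assistant = Langchain::Assistant.new(
  llm: llm,
  instructions: "You are a helpful assistant"
)
```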
@@ -48,13 +53,15 @@ module Langchain
  @instructions = instructions
  @state = :ready

+ @total_prompt_tokens = 0
+ @total_completion_tokens = 0
+ @total_tokens = 0
+
  raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless @thread.is_a?(Langchain::Thread)

  # The first message in the thread should be the system instructions
  # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
- if llm.is_a?(Langchain::LLM::OpenAI)
- add_message(role: "system", content: instructions) if instructions
- end
+ initialize_instructions
  # For Google Gemini and Anthropic, system instructions are added to the `system:` param in the `chat` method
  end

@@ -150,7 +157,6 @@ module Langchain

  # Handle the current state and transition to the next state
  #
- # @param state [Symbol] The current state
  # @return [Symbol] The next state
  def handle_state
  case @state
@@ -189,7 +195,6 @@ module Langchain

  # Handle LLM message scenario
  #
- # @param auto_tool_execution [Boolean] Flag to indicate if tools should be executed automatically
  # @return [Symbol] The next state
  def handle_llm_message
  thread.messages.last.tool_calls.any? ? :requires_action : :completed
@@ -208,14 +213,29 @@ module Langchain
  # @return [Symbol] The next state
  def handle_user_or_tool_message
  response = chat_with_llm
- add_message(role: response.role, content: response.chat_completion, tool_calls: response.tool_calls)

+ # With Ollama, we're calling the `llm.complete()` method
+ content = if llm.is_a?(Langchain::LLM::Ollama)
+ response.completion
+ else
+ response.chat_completion
+ end
+
+ add_message(role: response.role, content: content, tool_calls: response.tool_calls)
+ record_used_tokens(response.prompt_tokens, response.completion_tokens, response.total_tokens)
+
+ set_state_for(response: response)
+ end
+
+ def set_state_for(response:)
  if response.tool_calls.any?
  :in_progress
  elsif response.chat_completion
  :completed
+ elsif response.completion # Currently only used by Ollama
+ :completed
  else
- Langchain.logger.error("LLM response does not contain tool calls or chat completion")
+ Langchain.logger.error("LLM response does not contain tool calls, chat or completion response")
  :failed
  end
  end
@@ -236,6 +256,8 @@ module Langchain
  # @return [String] The tool role
  def determine_tool_role
  case llm
+ when Langchain::LLM::Ollama
+ Langchain::Messages::OllamaMessage::TOOL_ROLE
  when Langchain::LLM::OpenAI
  Langchain::Messages::OpenAIMessage::TOOL_ROLE
  when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
@@ -245,31 +267,58 @@ module Langchain
  end
  end

+ def initialize_instructions
+ if llm.is_a?(Langchain::LLM::Ollama)
+ content = String.new # rubocop: disable Performance/UnfreezeString
+ if tools.any?
+ content << %([AVAILABLE_TOOLS] #{tools.map(&:to_openai_tools).flatten}[/AVAILABLE_TOOLS])
+ end
+ if instructions
+ content << "[INST] #{instructions}[/INST]"
+ end
+
+ add_message(role: "system", content: content)
+ elsif llm.is_a?(Langchain::LLM::OpenAI)
+ add_message(role: "system", content: instructions) if instructions
+ end
+ end
+
  # Call to the LLM#chat() method
  #
  # @return [Langchain::LLM::BaseResponse] The LLM response object
  def chat_with_llm
  Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)

- params = {messages: thread.array_of_message_hashes}
+ params = {}

- if tools.any?
- if llm.is_a?(Langchain::LLM::OpenAI)
+ if llm.is_a?(Langchain::LLM::OpenAI)
+ if tools.any?
  params[:tools] = tools.map(&:to_openai_tools).flatten
  params[:tool_choice] = "auto"
- elsif llm.is_a?(Langchain::LLM::Anthropic)
+ end
+ elsif llm.is_a?(Langchain::LLM::Anthropic)
+ if tools.any?
  params[:tools] = tools.map(&:to_anthropic_tools).flatten
- params[:system] = instructions if instructions
  params[:tool_choice] = {type: "auto"}
- elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+ end
+ params[:system] = instructions if instructions
+ elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
+ if tools.any?
  params[:tools] = tools.map(&:to_google_gemini_tools).flatten
  params[:system] = instructions if instructions
  params[:tool_choice] = "auto"
  end
- # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
  end
+ # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.

- llm.chat(**params)
+ if llm.is_a?(Langchain::LLM::Ollama)
+ params[:raw] = true
+ params[:prompt] = thread.prompt_of_concatenated_messages
+ llm.complete(**params)
+ else
+ params[:messages] = thread.array_of_message_hashes
+ llm.chat(**params)
+ end
  end

  # Run the tools automatically
@@ -278,7 +327,9 @@ module Langchain
  def run_tools(tool_calls)
  # Iterate over each function invocation and submit tool output
  tool_calls.each do |tool_call|
- tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
+ tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::Ollama)
+ extract_ollama_tool_call(tool_call: tool_call)
+ elsif llm.is_a?(Langchain::LLM::OpenAI)
  extract_openai_tool_call(tool_call: tool_call)
  elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
  extract_google_gemini_tool_call(tool_call: tool_call)
@@ -296,6 +347,12 @@ module Langchain
  end
  end

+ def extract_ollama_tool_call(tool_call:)
+ tool_name, method_name = tool_call.dig("name").split("__")
+ tool_arguments = tool_call.dig("arguments").transform_keys(&:to_sym)
+ [nil, tool_name, method_name, tool_arguments]
+ end
+
  # Extract the tool call information from the OpenAI tool call hash
  #
  # @param tool_call [Hash] The tool call hash
@@ -346,7 +403,9 @@ module Langchain
  # @param tool_call_id [String] The ID of the tool call to include in the message
  # @return [Langchain::Message] The Message object
  def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
- if llm.is_a?(Langchain::LLM::OpenAI)
+ if llm.is_a?(Langchain::LLM::Ollama)
+ Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+ elsif llm.is_a?(Langchain::LLM::OpenAI)
  Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
  elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
  Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
@@ -355,6 +414,18 @@ module Langchain
  end
  end

+ # Increment the tokens count based on the last interaction with the LLM
+ #
+ # @param prompt_tokens [Integer] The number of used prompt tokens
+ # @param completion_tokens [Integer] The number of used completion tokens
+ # @param total_tokens_from_operation [Integer] The total number of used tokens
+ # @return [Integer] The current total tokens count
+ def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_operation)
+ @total_prompt_tokens += prompt_tokens if prompt_tokens
+ @total_completion_tokens += completion_tokens if completion_tokens
+ @total_tokens += total_tokens_from_operation if total_tokens_from_operation
+ end
+
  # TODO: Fix the message truncation when context window is exceeded
  end
  end
@@ -0,0 +1,86 @@
+ # frozen_string_literal: true
+
+ module Langchain
+ module Messages
+ class OllamaMessage < Base
+ # Ollama supports the following roles:
+ ROLES = [
+ "system",
+ "assistant",
+ "user",
+ "tool"
+ ].freeze
+
+ TOOL_ROLE = "tool"
+
+ # Initialize a new Ollama message
+ #
+ # @param [String] The role of the message
+ # @param [String] The content of the message
+ # @param [Array<Hash>] The tool calls made in the message
+ # @param [String] The ID of the tool call
+ def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+ raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+ @role = role
+ # Some Tools return content as a JSON hence `.to_s`
+ @content = content.to_s
+ @tool_calls = tool_calls
+ @tool_call_id = tool_call_id
+ end
+
+ def to_s
+ send(:"to_#{role}_message_string")
+ end
+
+ def to_system_message_string
+ content
+ end
+
+ def to_user_message_string
+ "[INST] #{content}[/INST]"
+ end
+
+ def to_tool_message_string
+ "[TOOL_RESULTS] #{content}[/TOOL_RESULTS]"
+ end
+
+ def to_assistant_message_string
+ if tool_calls.any?
+ %("[TOOL_CALLS] #{tool_calls}")
+ else
+ content
+ end
+ end
+
+ # Check if the message came from an LLM
+ #
+ # @return [Boolean] true/false whether this message was produced by an LLM
+ def llm?
+ assistant?
+ end
+
+ # Check if the message came from an LLM
+ #
+ # @return [Boolean] true/false whether this message was produced by an LLM
+ def assistant?
+ role == "assistant"
+ end
+
+ # Check if the message contains system instructions
+ #
+ # @return [Boolean] true/false whether this message contains system instructions
+ def system?
+ role == "system"
+ end
+
+ # Check if the message is a tool call
+ #
+ # @return [Boolean] true/false whether this message is a tool call
+ def tool?
+ role == "tool"
+ end
+ end
+ end
+ end
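The `to_*_message_string` methods above render each message into Mistral's instruct format, and the Assistant concatenates them via `Thread#prompt_of_concatenated_messages` (see the Thread diff below). A quick sketch of the resulting prompt (message contents are made up for illustration):

```ruby
messages = [
  Langchain::Messages::OllamaMessage.new(role: "user", content: "What is 2 + 2?"),
  Langchain::Messages::OllamaMessage.new(role: "assistant", content: "4")
]

# This is what prompt_of_concatenated_messages produces
messages.map(&:to_s).join
# => "[INST] What is 2 + 2?[/INST]4"
```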
@@ -17,7 +17,14 @@ module Langchain
  #
  # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
  def array_of_message_hashes
- messages.map(&:to_hash)
+ messages
+ .map(&:to_hash)
+ .compact
+ end
+
+ # Only used by the Assistant when it calls the LLM#complete() method
+ def prompt_of_concatenated_messages
+ messages.map(&:to_s).join
  end

  # Add a message to the thread
@@ -16,8 +16,6 @@ module Langchain::LLM
  model: "j2-ultra"
  }.freeze

- LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AI21Validator
-
  def initialize(api_key:, default_options: {})
  depends_on "ai21"

@@ -35,8 +33,6 @@ module Langchain::LLM
  def complete(prompt:, **params)
  parameters = complete_parameters params

- parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], {llm: client})
-
  response = client.complete(prompt, parameters)
  Langchain::LLM::AI21Response.new response, model: parameters[:model]
  end
@@ -5,10 +5,10 @@ module Langchain::LLM
  # Wrapper around Anthropic APIs.
  #
  # Gem requirements:
- # gem "anthropic", "~> 0.1.0"
+ # gem "anthropic", "~> 0.3.0"
  #
  # Usage:
- # anthorpic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
+ # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
  #
  class Anthropic < Base
  DEFAULTS = {
@@ -18,9 +18,6 @@ module Langchain::LLM
  max_tokens_to_sample: 256
  }.freeze

- # TODO: Implement token length validator for Anthropic
- # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AnthropicValidator
-
  # Initialize an Anthropic LLM instance
  #
  # @param api_key [String] The API key to use
@@ -81,7 +78,10 @@ module Langchain::LLM
  parameters[:metadata] = metadata if metadata
  parameters[:stream] = stream if stream

- response = client.complete(parameters: parameters)
+ response = with_api_error_handling do
+ client.complete(parameters: parameters)
+ end
+
  Langchain::LLM::AnthropicResponse.new(response)
  end

@@ -114,6 +114,15 @@ module Langchain::LLM
  Langchain::LLM::AnthropicResponse.new(response)
  end

+ def with_api_error_handling
+ response = yield
+ return if response.empty?
+
+ raise Langchain::LLM::ApiError.new "Anthropic API error: #{response.dig("error", "message")}" if response&.dig("error")
+
+ response
+ end
+
  private

  def set_extra_headers!
@@ -42,17 +42,17 @@ module Langchain::LLM

  def embed(...)
  @client = @embed_client
- super(...)
+ super
  end

  def complete(...)
  @client = @chat_client
- super(...)
+ super
  end

  def chat(...)
  @client = @chat_client
- super(...)
+ super
  end
  end
  end
@@ -8,6 +8,7 @@ module Langchain::LLM
  # Langchain.rb provides a common interface to interact with all supported LLMs:
  #
  # - {Langchain::LLM::AI21}
+ # - {Langchain::LLM::Anthropic}
  # - {Langchain::LLM::Azure}
  # - {Langchain::LLM::Cohere}
  # - {Langchain::LLM::GooglePalm}
@@ -74,8 +74,6 @@ module Langchain::LLM

  default_params.merge!(params)

- default_params[:max_tokens] = Langchain::Utils::TokenLength::CohereValidator.validate_max_tokens!(prompt, default_params[:model], llm: client)
-
  response = client.generate(**default_params)
  Langchain::LLM::CohereResponse.new response, model: @defaults[:completion_model_name]
  end
@@ -18,7 +18,7 @@ module Langchain::LLM
  chat_completion_model_name: "chat-bison-001",
  embeddings_model_name: "embedding-gecko-001"
  }.freeze
- LENGTH_VALIDATOR = Langchain::Utils::TokenLength::GooglePalmValidator
+
  ROLE_MAPPING = {
  "assistant" => "ai"
  }
@@ -96,9 +96,6 @@ module Langchain::LLM
  examples: compose_examples(examples)
  }

- # chat-bison-001 is the only model that currently supports countMessageTokens functions
- LENGTH_VALIDATOR.validate_max_tokens!(default_params[:messages], "chat-bison-001", llm: self)
-
  if options[:stop_sequences]
  default_params[:stop] = options.delete(:stop_sequences)
  end
@@ -14,7 +14,7 @@ module Langchain::LLM
  attr_reader :url, :defaults

  DEFAULTS = {
- temperature: 0.8,
+ temperature: 0.0,
  completion_model_name: "llama3",
  embeddings_model_name: "llama3",
  chat_completion_model_name: "llama3"
@@ -3,7 +3,7 @@
  module Langchain::LLM
  class GoogleGeminiResponse < BaseResponse
  def initialize(raw_response, model: nil)
- super(raw_response, model: model)
+ super
  end

  def chat_completion
@@ -36,7 +36,7 @@ module Langchain::LLM
  end

  def prompt_tokens
- raw_response.dig("prompt_eval_count") if done?
+ raw_response.fetch("prompt_eval_count", 0) if done?
  end

  def completion_tokens
@@ -47,6 +47,24 @@ module Langchain::LLM
  prompt_tokens + completion_tokens if done?
  end

+ def tool_calls
+ if chat_completion && (parsed_tool_calls = JSON.parse(chat_completion))
+ [parsed_tool_calls]
+ elsif completion&.include?("[TOOL_CALLS]") && (
+ parsed_tool_calls = JSON.parse(
+ completion
+ # Slice out the serialized JSON
+ .slice(/\{.*\}/)
+ # Replace hash rocket with colon
+ .gsub("=>", ":")
+ )
+ )
+ [parsed_tool_calls]
+ else
+ []
+ end
+ end
+
  private

  def done?
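To illustrate what the fallback branch above parses, here is a hedged sketch of a raw Mistral completion it could handle; the tool name and arguments are made-up placeholders, not output confirmed by this diff:

```ruby
require "json"

completion = '[TOOL_CALLS] {"name" => "calculator__execute", "arguments" => {"input" => "2+2"}}'

# Mirrors the parsing in OllamaResponse#tool_calls
json = completion.slice(/\{.*\}/).gsub("=>", ":")
JSON.parse(json)
# => {"name"=>"calculator__execute", "arguments"=>{"input"=>"2+2"}}
```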
@@ -140,7 +140,7 @@ module Langchain::Vectorsearch

  client.search(
  collection_name: index_name,
- output_fields: ["id", "content", "vectors"],
+ output_fields: ["id", "content"], # Add "vectors" if you need the full vectors returned.
  top_k: k.to_s,
  vectors: [embedding],
  dsl_type: 1,
@@ -1,5 +1,5 @@
  # frozen_string_literal: true

  module Langchain
- VERSION = "0.13.5"
+ VERSION = "0.14.0"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: langchainrb
  version: !ruby/object:Gem::Version
- version: 0.13.5
+ version: 0.14.0
  platform: ruby
  authors:
  - Andrei Bondarev
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2024-07-01 00:00:00.000000000 Z
+ date: 2024-07-12 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: baran
@@ -212,14 +212,14 @@ dependencies:
  requirements:
  - - "~>"
  - !ruby/object:Gem::Version
- version: '0.2'
+ version: '0.3'
  type: :development
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - "~>"
  - !ruby/object:Gem::Version
- version: '0.2'
+ version: '0.3'
  - !ruby/object:Gem::Dependency
  name: aws-sdk-bedrockruntime
  requirement: !ruby/object:Gem::Requirement
@@ -682,20 +682,6 @@ dependencies:
  - - "~>"
  - !ruby/object:Gem::Version
  version: 0.1.0
- - !ruby/object:Gem::Dependency
- name: tiktoken_ruby
- requirement: !ruby/object:Gem::Requirement
- requirements:
- - - "~>"
- - !ruby/object:Gem::Version
- version: 0.0.9
- type: :development
- prerelease: false
- version_requirements: !ruby/object:Gem::Requirement
- requirements:
- - - "~>"
- - !ruby/object:Gem::Version
- version: 0.0.9
  description: Build LLM-backed Ruby applications with Ruby's Langchain.rb
  email:
  - andrei.bondarev13@gmail.com
@@ -711,6 +697,7 @@ files:
  - lib/langchain/assistants/messages/anthropic_message.rb
  - lib/langchain/assistants/messages/base.rb
  - lib/langchain/assistants/messages/google_gemini_message.rb
+ - lib/langchain/assistants/messages/ollama_message.rb
  - lib/langchain/assistants/messages/openai_message.rb
  - lib/langchain/assistants/thread.rb
  - lib/langchain/chunk.rb
@@ -810,12 +797,6 @@ files:
  - lib/langchain/tool/wikipedia/wikipedia.rb
  - lib/langchain/utils/cosine_similarity.rb
  - lib/langchain/utils/hash_transformer.rb
- - lib/langchain/utils/token_length/ai21_validator.rb
- - lib/langchain/utils/token_length/base_validator.rb
- - lib/langchain/utils/token_length/cohere_validator.rb
- - lib/langchain/utils/token_length/google_palm_validator.rb
- - lib/langchain/utils/token_length/openai_validator.rb
- - lib/langchain/utils/token_length/token_limit_exceeded.rb
  - lib/langchain/vectorsearch/base.rb
  - lib/langchain/vectorsearch/chroma.rb
  - lib/langchain/vectorsearch/elasticsearch.rb
@@ -1,41 +0,0 @@
- # frozen_string_literal: true
-
- module Langchain
- module Utils
- module TokenLength
- #
- # This class is meant to validate the length of the text passed in to AI21's API.
- # It is used to validate the token length before the API call is made
- #
-
- class AI21Validator < BaseValidator
- TOKEN_LIMITS = {
- "j2-ultra" => 8192,
- "j2-mid" => 8192,
- "j2-light" => 8192
- }.freeze
-
- #
- # Calculate token length for a given text and model name
- #
- # @param text [String] The text to calculate the token length for
- # @param model_name [String] The model name to validate against
- # @return [Integer] The token length of the text
- #
- def self.token_length(text, model_name, options = {})
- res = options[:llm].tokenize(text)
- res.dig(:tokens).length
- end
-
- def self.token_limit(model_name)
- TOKEN_LIMITS[model_name]
- end
- singleton_class.alias_method :completion_token_limit, :token_limit
-
- def self.token_length_from_messages(messages, model_name, options)
- messages.sum { |message| token_length(message.to_json, model_name, options) }
- end
- end
- end
- end
- end
@@ -1,42 +0,0 @@
- # frozen_string_literal: true
-
- module Langchain
- module Utils
- module TokenLength
- #
- # Calculate the `max_tokens:` parameter to be set by calculating the context length of the text minus the prompt length
- #
- # @param content [String | Array<String>] The text or array of texts to validate
- # @param model_name [String] The model name to validate against
- # @return [Integer] Whether the text is valid or not
- # @raise [TokenLimitExceeded] If the text is too long
- #
- class BaseValidator
- def self.validate_max_tokens!(content, model_name, options = {})
- text_token_length = if content.is_a?(Array)
- token_length_from_messages(content, model_name, options)
- else
- token_length(content, model_name, options)
- end
-
- leftover_tokens = token_limit(model_name) - text_token_length
-
- # Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
- # We want the lower of the two limits
- max_tokens = [leftover_tokens, completion_token_limit(model_name)].min
-
- # Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
- if max_tokens < 0
- raise limit_exceeded_exception(token_limit(model_name), text_token_length)
- end
-
- max_tokens
- end
-
- def self.limit_exceeded_exception(limit, length)
- TokenLimitExceeded.new("This model's maximum context length is #{limit} tokens, but the given text is #{length} tokens long.", length - limit)
- end
- end
- end
- end
- end
@@ -1,49 +0,0 @@
- # frozen_string_literal: true
-
- module Langchain
- module Utils
- module TokenLength
- #
- # This class is meant to validate the length of the text passed in to Cohere's API.
- # It is used to validate the token length before the API call is made
- #
-
- class CohereValidator < BaseValidator
- TOKEN_LIMITS = {
- # Source:
- # https://docs.cohere.com/docs/models
- "command-light" => 4096,
- "command" => 4096,
- "base-light" => 2048,
- "base" => 2048,
- "embed-english-light-v2.0" => 512,
- "embed-english-v2.0" => 512,
- "embed-multilingual-v2.0" => 256,
- "summarize-medium" => 2048,
- "summarize-xlarge" => 2048
- }.freeze
-
- #
- # Calculate token length for a given text and model name
- #
- # @param text [String] The text to calculate the token length for
- # @param model_name [String] The model name to validate against
- # @return [Integer] The token length of the text
- #
- def self.token_length(text, model_name, options = {})
- res = options[:llm].tokenize(text: text)
- res["tokens"].length
- end
-
- def self.token_limit(model_name)
- TOKEN_LIMITS[model_name]
- end
- singleton_class.alias_method :completion_token_limit, :token_limit
-
- def self.token_length_from_messages(messages, model_name, options)
- messages.sum { |message| token_length(message.to_json, model_name, options) }
- end
- end
- end
- end
- end
@@ -1,57 +0,0 @@
- # frozen_string_literal: true
-
- module Langchain
- module Utils
- module TokenLength
- #
- # This class is meant to validate the length of the text passed in to Google Palm's API.
- # It is used to validate the token length before the API call is made
- #
- class GooglePalmValidator < BaseValidator
- TOKEN_LIMITS = {
- # Source:
- # This data can be pulled when `list_models()` method is called: https://github.com/andreibondarev/google_palm_api#usage
-
- # chat-bison-001 is the only model that currently supports countMessageTokens functions
- "chat-bison-001" => {
- "input_token_limit" => 4000, # 4096 is the limit but the countMessageTokens does not return anything higher than 4000
- "output_token_limit" => 1024
- }
- # "text-bison-001" => {
- # "input_token_limit" => 8196,
- # "output_token_limit" => 1024
- # },
- # "embedding-gecko-001" => {
- # "input_token_limit" => 1024
- # }
- }.freeze
-
- #
- # Calculate token length for a given text and model name
- #
- # @param text [String] The text to calculate the token length for
- # @param model_name [String] The model name to validate against
- # @param options [Hash] the options to create a message with
- # @option options [Langchain::LLM:GooglePalm] :llm The Langchain::LLM:GooglePalm instance
- # @return [Integer] The token length of the text
- #
- def self.token_length(text, model_name = "chat-bison-001", options = {})
- response = options[:llm].client.count_message_tokens(model: model_name, prompt: text)
-
- raise Langchain::LLM::ApiError.new(response["error"]["message"]) unless response["error"].nil?
-
- response.dig("tokenCount")
- end
-
- def self.token_length_from_messages(messages, model_name, options = {})
- messages.sum { |message| token_length(message.to_json, model_name, options) }
- end
-
- def self.token_limit(model_name)
- TOKEN_LIMITS.dig(model_name, "input_token_limit")
- end
- singleton_class.alias_method :completion_token_limit, :token_limit
- end
- end
- end
- end
@@ -1,138 +0,0 @@
- # frozen_string_literal: true
-
- require "tiktoken_ruby"
-
- module Langchain
- module Utils
- module TokenLength
- #
- # This class is meant to validate the length of the text passed in to OpenAI's API.
- # It is used to validate the token length before the API call is made
- #
- class OpenAIValidator < BaseValidator
- COMPLETION_TOKEN_LIMITS = {
- # GPT-4 Turbo has a separate token limit for completion
- # Source:
- # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
- "gpt-4-1106-preview" => 4096,
- "gpt-4-vision-preview" => 4096,
- "gpt-3.5-turbo-1106" => 4096
- }
-
- # NOTE: The gpt-4-turbo-preview is an alias that will always point to the latest GPT 4 Turbo preview
- # the future previews may have a different token limit!
- TOKEN_LIMITS = {
- # Source:
- # https://platform.openai.com/docs/api-reference/embeddings
- # https://platform.openai.com/docs/models/gpt-4
- "text-embedding-3-large" => 8191,
- "text-embedding-3-small" => 8191,
- "text-embedding-ada-002" => 8191,
- "gpt-3.5-turbo" => 16385,
- "gpt-3.5-turbo-0301" => 4096,
- "gpt-3.5-turbo-0613" => 4096,
- "gpt-3.5-turbo-1106" => 16385,
- "gpt-3.5-turbo-0125" => 16385,
- "gpt-3.5-turbo-16k" => 16384,
- "gpt-3.5-turbo-16k-0613" => 16384,
- "text-davinci-003" => 4097,
- "text-davinci-002" => 4097,
- "code-davinci-002" => 8001,
- "gpt-4" => 8192,
- "gpt-4-0314" => 8192,
- "gpt-4-0613" => 8192,
- "gpt-4-32k" => 32768,
- "gpt-4-32k-0314" => 32768,
- "gpt-4-32k-0613" => 32768,
- "gpt-4-1106-preview" => 128000,
- "gpt-4-turbo" => 128000,
- "gpt-4-turbo-2024-04-09" => 128000,
- "gpt-4-turbo-preview" => 128000,
- "gpt-4-0125-preview" => 128000,
- "gpt-4-vision-preview" => 128000,
- "gpt-4o" => 128000,
- "gpt-4o-2024-05-13" => 128000,
- "text-curie-001" => 2049,
- "text-babbage-001" => 2049,
- "text-ada-001" => 2049,
- "davinci" => 2049,
- "curie" => 2049,
- "babbage" => 2049,
- "ada" => 2049
- }.freeze
-
- #
- # Calculate token length for a given text and model name
- #
- # @param text [String] The text to calculate the token length for
- # @param model_name [String] The model name to validate against
- # @return [Integer] The token length of the text
- #
- def self.token_length(text, model_name, options = {})
- # tiktoken-ruby doesn't support text-embedding-3-large or text-embedding-3-small yet
- if ["text-embedding-3-large", "text-embedding-3-small"].include?(model_name)
- model_name = "text-embedding-ada-002"
- end
-
- encoder = Tiktoken.encoding_for_model(model_name)
- encoder.encode(text).length
- end
-
- def self.token_limit(model_name)
- TOKEN_LIMITS[model_name]
- end
-
- def self.completion_token_limit(model_name)
- COMPLETION_TOKEN_LIMITS[model_name] || token_limit(model_name)
- end
-
- # If :max_tokens is passed in, take the lower of it and the calculated max_tokens
- def self.validate_max_tokens!(content, model_name, options = {})
- max_tokens = super(content, model_name, options)
- [options[:max_tokens], max_tokens].reject(&:nil?).min
- end
-
- # Copied from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
- # Return the number of tokens used by a list of messages
- #
- # @param messages [Array<Hash>] The messages to calculate the token length for
- # @param model [String] The model name to validate against
- # @return [Integer] The token length of the messages
- #
- def self.token_length_from_messages(messages, model_name, options = {})
- encoding = Tiktoken.encoding_for_model(model_name)
-
- if ["gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613"].include?(model_name)
- tokens_per_message = 3
- tokens_per_name = 1
- elsif model_name == "gpt-3.5-turbo-0301"
- tokens_per_message = 4 # every message follows {role/name}\n{content}\n
- tokens_per_name = -1 # if there's a name, the role is omitted
- elsif model_name.include?("gpt-3.5-turbo")
- # puts "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
- return token_length_from_messages(messages, "gpt-3.5-turbo-0613", options)
- elsif model_name.include?("gpt-4")
- # puts "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
- return token_length_from_messages(messages, "gpt-4-0613", options)
- else
- raise NotImplementedError.new(
- "token_length_from_messages() is not implemented for model #{model_name}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."
- )
- end
-
- num_tokens = 0
- messages.each do |message|
- num_tokens += tokens_per_message
- message.each do |key, value|
- num_tokens += encoding.encode(value).length
- num_tokens += tokens_per_name if ["name", :name].include?(key)
- end
- end
-
- num_tokens += 3 # every reply is primed with assistant
- num_tokens
- end
- end
- end
- end
- end
@@ -1,17 +0,0 @@
- # frozen_string_literal: true
-
- module Langchain
- module Utils
- module TokenLength
- class TokenLimitExceeded < StandardError
- attr_reader :token_overflow
-
- def initialize(message = "", token_overflow = 0)
- super(message)
-
- @token_overflow = token_overflow
- end
- end
- end
- end
- end