langchainrb 0.13.5 → 0.15.0

Files changed (62)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +14 -0
  3. data/README.md +2 -17
  4. data/lib/langchain/assistants/assistant.rb +207 -92
  5. data/lib/langchain/assistants/messages/ollama_message.rb +74 -0
  6. data/lib/langchain/assistants/thread.rb +8 -1
  7. data/lib/langchain/contextual_logger.rb +2 -2
  8. data/lib/langchain/llm/ai21.rb +0 -4
  9. data/lib/langchain/llm/anthropic.rb +15 -6
  10. data/lib/langchain/llm/azure.rb +3 -3
  11. data/lib/langchain/llm/base.rb +1 -0
  12. data/lib/langchain/llm/cohere.rb +0 -2
  13. data/lib/langchain/llm/google_gemini.rb +1 -1
  14. data/lib/langchain/llm/google_palm.rb +1 -4
  15. data/lib/langchain/llm/ollama.rb +24 -18
  16. data/lib/langchain/llm/openai.rb +1 -1
  17. data/lib/langchain/llm/response/google_gemini_response.rb +1 -1
  18. data/lib/langchain/llm/response/ollama_response.rb +5 -1
  19. data/lib/langchain/llm/unified_parameters.rb +2 -2
  20. data/lib/langchain/tool/calculator.rb +38 -0
  21. data/lib/langchain/tool/{database/database.rb → database.rb} +24 -12
  22. data/lib/langchain/tool/file_system.rb +44 -0
  23. data/lib/langchain/tool/{google_search/google_search.rb → google_search.rb} +17 -23
  24. data/lib/langchain/tool/{news_retriever/news_retriever.rb → news_retriever.rb} +41 -14
  25. data/lib/langchain/tool/ruby_code_interpreter.rb +41 -0
  26. data/lib/langchain/tool/{tavily/tavily.rb → tavily.rb} +24 -10
  27. data/lib/langchain/tool/vectorsearch.rb +40 -0
  28. data/lib/langchain/tool/{weather/weather.rb → weather.rb} +21 -17
  29. data/lib/langchain/tool/{wikipedia/wikipedia.rb → wikipedia.rb} +17 -13
  30. data/lib/langchain/tool_definition.rb +212 -0
  31. data/lib/langchain/utils/hash_transformer.rb +9 -17
  32. data/lib/langchain/vectorsearch/chroma.rb +2 -2
  33. data/lib/langchain/vectorsearch/elasticsearch.rb +2 -2
  34. data/lib/langchain/vectorsearch/epsilla.rb +3 -3
  35. data/lib/langchain/vectorsearch/milvus.rb +3 -3
  36. data/lib/langchain/vectorsearch/pgvector.rb +2 -2
  37. data/lib/langchain/vectorsearch/pinecone.rb +2 -2
  38. data/lib/langchain/vectorsearch/qdrant.rb +2 -2
  39. data/lib/langchain/vectorsearch/weaviate.rb +4 -4
  40. data/lib/langchain/version.rb +1 -1
  41. metadata +16 -45
  42. data/lib/langchain/tool/base.rb +0 -107
  43. data/lib/langchain/tool/calculator/calculator.json +0 -19
  44. data/lib/langchain/tool/calculator/calculator.rb +0 -34
  45. data/lib/langchain/tool/database/database.json +0 -46
  46. data/lib/langchain/tool/file_system/file_system.json +0 -57
  47. data/lib/langchain/tool/file_system/file_system.rb +0 -32
  48. data/lib/langchain/tool/google_search/google_search.json +0 -19
  49. data/lib/langchain/tool/news_retriever/news_retriever.json +0 -122
  50. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.json +0 -19
  51. data/lib/langchain/tool/ruby_code_interpreter/ruby_code_interpreter.rb +0 -37
  52. data/lib/langchain/tool/tavily/tavily.json +0 -54
  53. data/lib/langchain/tool/vectorsearch/vectorsearch.json +0 -24
  54. data/lib/langchain/tool/vectorsearch/vectorsearch.rb +0 -36
  55. data/lib/langchain/tool/weather/weather.json +0 -19
  56. data/lib/langchain/tool/wikipedia/wikipedia.json +0 -19
  57. data/lib/langchain/utils/token_length/ai21_validator.rb +0 -41
  58. data/lib/langchain/utils/token_length/base_validator.rb +0 -42
  59. data/lib/langchain/utils/token_length/cohere_validator.rb +0 -49
  60. data/lib/langchain/utils/token_length/google_palm_validator.rb +0 -57
  61. data/lib/langchain/utils/token_length/openai_validator.rb +0 -138
  62. data/lib/langchain/utils/token_length/token_limit_exceeded.rb +0 -17
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: d7eac7a6ba7767f6a3f84ee808fa4810eaa1843776695ab0225ddd6b77cf7a73
- data.tar.gz: e9f7c0170fc2a8dbf443f1bac24874878ee0fbba7e0495bf65a8df969d3d86e6
+ metadata.gz: dde504e05b1cbb32c857569bf71301537fed2deb468f1bdd69a7ef900a41c085
+ data.tar.gz: '08659cddd6f0bb285e167c7a35dbd2f83c2e9bb51a69206217ea91649e99839c'
  SHA512:
- metadata.gz: e4d14ac64e54e5c7245a9586dfb4899154793ea466f9564a510eb3dfe17a3a7229cf61e408445b38fec37500065b5e1ee725afa634284bea5538abac0766237f
- data.tar.gz: e8fe3e1639a3f2ed087436610dd1653e775703c1c6cc83f7f52eb7d3fb46db554e7be790bc6bc2ddf18ec4e3c26dddbe1ec72e8f25603db1192e5a111d0f9543
+ metadata.gz: ce4dd091498659a2d8dda4b54e9e9584dc19be5f390dc5f1d98efa054a264134dc3510f2f83c65bdf23edfbd7344587b91113e69c2ea1fea2cdc157317735799
+ data.tar.gz: a6df110aa7d96c87402164f67aadab0a97e2a62b68b7466cf630fe79dd0611a1740ae11163361eef9c98fc816f7ba12d7bfc0aa2225759cc8191f59fead8fcbd
data/CHANGELOG.md CHANGED
@@ -1,5 +1,19 @@
  ## [Unreleased]
 
+ ## [0.15.0] - 2024-08-14
+ - Fix Langchain::Assistant when llm is Anthropic
+ - Fix GoogleGemini#chat method
+ - Langchain::LLM::Weaviate initializer does not require api_key anymore
+ - [BREAKING] Langchain::LLM::OpenAI#chat() uses `gpt-4o-mini` by default instead of `gpt-3.5-turbo` previously.
+ - [BREAKING] Assistant works with a number of open-source models via Ollama.
+ - [BREAKING] Introduce new `Langchain::ToolDefinition` module to define tools. This replaces the previous reliance on subclassing from `Langchain::Tool::Base`.
+
+ ## [0.14.0] - 2024-07-12
+ - Removed TokenLength validators
+ - Assistant works with a Mistral LLM now
+ - Assistant keeps track of tokens used
+ - Misc fixes and improvements
+
  ## [0.13.5] - 2024-07-01
  - Add Milvus#remove_texts() method
  - Langchain::Assistant has a `state` now
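Under the new `Langchain::ToolDefinition` convention, a custom tool `extend`s the module and declares a schema per callable method, rather than subclassing `Langchain::Tool::Base`. A minimal sketch inferred from the assistant.rb diff below (the `MyWeather` class, its `get_weather` method, and the `city` parameter are illustrative, not part of the gem):

```ruby
class MyWeather
  extend Langchain::ToolDefinition

  # The declared schema is what the adapters below serialize via
  # function_schemas.to_openai_format / to_anthropic_format / to_google_gemini_format
  define_function :get_weather, description: "Get the current weather for a city" do
    property :city, type: "string", description: "City name", required: true
  end

  # Tool calls named "my_weather__get_weather" (tool_name__method_name)
  # are dispatched to this instance method by the Assistant
  def get_weather(city:)
    "It is sunny in #{city}"
  end
end
```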
data/README.md CHANGED
@@ -428,25 +428,10 @@ Assistants are Agent-like objects that leverage helpful instructions, LLMs, tool
  ```ruby
  llm = Langchain::LLM::OpenAI.new(api_key: ENV["OPENAI_API_KEY"])
  ```
- 2. Instantiate a Thread. Threads keep track of the messages in the Assistant conversation.
- ```ruby
- thread = Langchain::Thread.new
- ```
- You can pass old message from previously using the Assistant:
- ```ruby
- thread.messages = messages
- ```
- Messages contain the conversation history and the whole message history is sent to the LLM every time. A Message belongs to 1 of the 4 roles:
- * `Message(role: "system")` message usually contains the instructions.
- * `Message(role: "user")` messages come from the user.
- * `Message(role: "assistant")` messages are produced by the LLM.
- * `Message(role: "tool")` messages are sent in response to tool calls with tool outputs.
-
- 3. Instantiate an Assistant
+ 2. Instantiate an Assistant
  ```ruby
  assistant = Langchain::Assistant.new(
  llm: llm,
- thread: thread,
  instructions: "You are a Meteorologist Assistant that is able to pull the weather for any location",
  tools: [
  Langchain::Tool::Weather.new(api_key: ENV["OPEN_WEATHER_API_KEY"])
@@ -482,7 +467,7 @@ assistant.add_message_and_run content: "What about Sacramento, CA?", auto_tool_e
  ### Accessing Thread messages
  You can access the messages in a Thread by calling `assistant.thread.messages`.
  ```ruby
- assistant.thread.messages
+ assistant.messages
  ```
 
  The Assistant checks the context window limits before every request to the LLM and remove oldest thread messages one by one if the context window is exceeded.
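With the Thread now created internally, the interaction loop from the README reduces to the Assistant's own API. A sketch (the prompt text is illustrative):

```ruby
assistant.add_message_and_run content: "What's the weather in New York City?", auto_tool_execution: true
assistant.messages # delegated to the internally-created thread
```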
data/lib/langchain/assistants/assistant.rb CHANGED
@@ -16,15 +16,9 @@ module Langchain
  def_delegators :thread, :messages, :messages=
 
  attr_reader :llm, :thread, :instructions, :state
+ attr_reader :total_prompt_tokens, :total_completion_tokens, :total_tokens
  attr_accessor :tools
 
- SUPPORTED_LLMS = [
- Langchain::LLM::Anthropic,
- Langchain::LLM::OpenAI,
- Langchain::LLM::GoogleGemini,
- Langchain::LLM::GoogleVertexAI
- ]
-
  # Create a new assistant
  #
  # @param llm [Langchain::LLM::Base] LLM instance that the assistant will use
@@ -37,24 +31,26 @@ module Langchain
  tools: [],
  instructions: nil
  )
- unless SUPPORTED_LLMS.include?(llm.class)
- raise ArgumentError, "Invalid LLM; currently only #{SUPPORTED_LLMS.join(", ")} are supported"
+ unless tools.is_a?(Array) && tools.all? { |tool| tool.class.singleton_class.included_modules.include?(Langchain::ToolDefinition) }
+ raise ArgumentError, "Tools must be an array of objects extending Langchain::ToolDefinition"
  end
- raise ArgumentError, "Tools must be an array of Langchain::Tool::Base instance(s)" unless tools.is_a?(Array) && tools.all? { |tool| tool.is_a?(Langchain::Tool::Base) }
 
  @llm = llm
+ @llm_adapter = LLM::Adapter.build(llm)
  @thread = thread || Langchain::Thread.new
  @tools = tools
  @instructions = instructions
  @state = :ready
 
+ @total_prompt_tokens = 0
+ @total_completion_tokens = 0
+ @total_tokens = 0
+
  raise ArgumentError, "Thread must be an instance of Langchain::Thread" unless @thread.is_a?(Langchain::Thread)
 
  # The first message in the thread should be the system instructions
  # TODO: What if the user added old messages and the system instructions are already in there? Should this overwrite the existing instructions?
- if llm.is_a?(Langchain::LLM::OpenAI)
- add_message(role: "system", content: instructions) if instructions
- end
+ initialize_instructions
  # For Google Gemini, and Anthropic system instructions are added to the `system:` param in the `chat` method
  end
 
@@ -150,7 +146,6 @@ module Langchain
  # Handle the current state and transition to the next state
  #
- # @param state [Symbol] The current state
  # @return [Symbol] The next state
  def handle_state
  case @state
@@ -189,7 +184,6 @@ module Langchain
  # Handle LLM message scenario
  #
- # @param auto_tool_execution [Boolean] Flag to indicate if tools should be executed automatically
  # @return [Symbol] The next state
  def handle_llm_message
  thread.messages.last.tool_calls.any? ? :requires_action : :completed
@@ -208,14 +202,22 @@ module Langchain
  # @return [Symbol] The next state
  def handle_user_or_tool_message
  response = chat_with_llm
+
  add_message(role: response.role, content: response.chat_completion, tool_calls: response.tool_calls)
+ record_used_tokens(response.prompt_tokens, response.completion_tokens, response.total_tokens)
+
+ set_state_for(response: response)
+ end
 
+ def set_state_for(response:)
  if response.tool_calls.any?
  :in_progress
  elsif response.chat_completion
  :completed
+ elsif response.completion # Currently only used by Ollama
+ :completed
  else
- Langchain.logger.error("LLM response does not contain tool calls or chat completion")
+ Langchain.logger.error("LLM response does not contain tool calls, chat or completion response")
  :failed
  end
  end
@@ -227,7 +229,7 @@ module Langchain
  run_tools(thread.messages.last.tool_calls)
  :in_progress
  rescue => e
- Langchain.logger.error("Error running tools: #{e.message}")
+ Langchain.logger.error("Error running tools: #{e.message}; #{e.backtrace.join('\n')}")
  :failed
  end
 
@@ -236,6 +238,8 @@ module Langchain
  # @return [String] The tool role
  def determine_tool_role
  case llm
+ when Langchain::LLM::Ollama
+ Langchain::Messages::OllamaMessage::TOOL_ROLE
  when Langchain::LLM::OpenAI
  Langchain::Messages::OpenAIMessage::TOOL_ROLE
  when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
@@ -245,31 +249,24 @@ module Langchain
  end
  end
 
+ def initialize_instructions
+ if llm.is_a?(Langchain::LLM::OpenAI)
+ add_message(role: "system", content: instructions) if instructions
+ end
+ end
+
  # Call to the LLM#chat() method
  #
  # @return [Langchain::LLM::BaseResponse] The LLM response object
  def chat_with_llm
  Langchain.logger.info("Sending a call to #{llm.class}", for: self.class)
 
- params = {messages: thread.array_of_message_hashes}
-
- if tools.any?
- if llm.is_a?(Langchain::LLM::OpenAI)
- params[:tools] = tools.map(&:to_openai_tools).flatten
- params[:tool_choice] = "auto"
- elsif llm.is_a?(Langchain::LLM::Anthropic)
- params[:tools] = tools.map(&:to_anthropic_tools).flatten
- params[:system] = instructions if instructions
- params[:tool_choice] = {type: "auto"}
- elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
- params[:tools] = tools.map(&:to_google_gemini_tools).flatten
- params[:system] = instructions if instructions
- params[:tool_choice] = "auto"
- end
- # TODO: Not sure that tool_choice should always be "auto"; Maybe we can let the user toggle it.
- end
-
- llm.chat(**params)
+ params = @llm_adapter.build_chat_params(
+ tools: @tools,
+ instructions: @instructions,
+ messages: thread.array_of_message_hashes
+ )
+ @llm.chat(**params)
  end
 
  # Run the tools automatically
@@ -278,16 +275,10 @@ module Langchain
  def run_tools(tool_calls)
  # Iterate over each function invocation and submit tool output
  tool_calls.each do |tool_call|
- tool_call_id, tool_name, method_name, tool_arguments = if llm.is_a?(Langchain::LLM::OpenAI)
- extract_openai_tool_call(tool_call: tool_call)
- elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
- extract_google_gemini_tool_call(tool_call: tool_call)
- elsif llm.is_a?(Langchain::LLM::Anthropic)
- extract_anthropic_tool_call(tool_call: tool_call)
- end
+ tool_call_id, tool_name, method_name, tool_arguments = @llm_adapter.extract_tool_call_args(tool_call: tool_call)
 
  tool_instance = tools.find do |t|
- t.name == tool_name
+ t.class.tool_name == tool_name
  end or raise ArgumentError, "Tool not found in assistant.tools"
 
  output = tool_instance.send(method_name, **tool_arguments)
@@ -296,65 +287,189 @@ module Langchain
  end
  end
 
- # Extract the tool call information from the OpenAI tool call hash
+ # Build a message
  #
- # @param tool_call [Hash] The tool call hash
- # @return [Array] The tool call information
- def extract_openai_tool_call(tool_call:)
- tool_call_id = tool_call.dig("id")
-
- function_name = tool_call.dig("function", "name")
- tool_name, method_name = function_name.split("__")
- tool_arguments = JSON.parse(tool_call.dig("function", "arguments"), symbolize_names: true)
-
- [tool_call_id, tool_name, method_name, tool_arguments]
+ # @param role [String] The role of the message
+ # @param content [String] The content of the message
+ # @param tool_calls [Array<Hash>] The tool calls to include in the message
+ # @param tool_call_id [String] The ID of the tool call to include in the message
+ # @return [Langchain::Message] The Message object
+ def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ @llm_adapter.build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
  end
 
- # Extract the tool call information from the Anthropic tool call hash
+ # Increment the tokens count based on the last interaction with the LLM
  #
- # @param tool_call [Hash] The tool call hash, format: {"type"=>"tool_use", "id"=>"toolu_01TjusbFApEbwKPRWTRwzadR", "name"=>"news_retriever__get_top_headlines", "input"=>{"country"=>"us", "page_size"=>10}}], "stop_reason"=>"tool_use"}
- # @return [Array] The tool call information
- def extract_anthropic_tool_call(tool_call:)
- tool_call_id = tool_call.dig("id")
+ # @param prompt_tokens [Integer] The number of used prmopt tokens
+ # @param completion_tokens [Integer] The number of used completion tokens
+ # @param total_tokens [Integer] The total number of used tokens
+ # @return [Integer] The current total tokens count
+ def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_operation)
+ @total_prompt_tokens += prompt_tokens if prompt_tokens
+ @total_completion_tokens += completion_tokens if completion_tokens
+ @total_tokens += total_tokens_from_operation if total_tokens_from_operation
+ end
 
- function_name = tool_call.dig("name")
- tool_name, method_name = function_name.split("__")
- tool_arguments = tool_call.dig("input").transform_keys(&:to_sym)
+ # TODO: Fix the message truncation when context window is exceeded
 
- [tool_call_id, tool_name, method_name, tool_arguments]
- end
+ module LLM
+ class Adapter
+ def self.build(llm)
+ case llm
+ when Langchain::LLM::Ollama
+ Adapters::Ollama.new
+ when Langchain::LLM::OpenAI
+ Adapters::OpenAI.new
+ when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
+ Adapters::GoogleGemini.new
+ when Langchain::LLM::Anthropic
+ Adapters::Anthropic.new
+ else
+ raise ArgumentError, "Unsupported LLM type: #{llm.class}"
+ end
+ end
+ end
 
- # Extract the tool call information from the Google Gemini tool call hash
- #
- # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
- # @return [Array] The tool call information
- def extract_google_gemini_tool_call(tool_call:)
- tool_call_id = tool_call.dig("functionCall", "name")
+ module Adapters
+ class Base
+ def build_chat_params(tools:, instructions:, messages:)
+ raise NotImplementedError, "Subclasses must implement build_chat_params"
+ end
 
- function_name = tool_call.dig("functionCall", "name")
- tool_name, method_name = function_name.split("__")
- tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
+ def extract_tool_call_args(tool_call:)
+ raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
+ end
 
- [tool_call_id, tool_name, method_name, tool_arguments]
- end
+ def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ raise NotImplementedError, "Subclasses must implement build_message"
+ end
+ end
 
- # Build a message
- #
- # @param role [String] The role of the message
- # @param content [String] The content of the message
- # @param tool_calls [Array<Hash>] The tool calls to include in the message
- # @param tool_call_id [String] The ID of the tool call to include in the message
- # @return [Langchain::Message] The Message object
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
- if llm.is_a?(Langchain::LLM::OpenAI)
- Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
- elsif [Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI].include?(llm.class)
- Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
- elsif llm.is_a?(Langchain::LLM::Anthropic)
- Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+ class Ollama < Base
+ def build_chat_params(tools:, instructions:, messages:)
+ params = {messages: messages}
+ if tools.any?
+ params[:tools] = tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
+ end
+ params
+ end
+
+ def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+ end
+
+ # Extract the tool call information from the OpenAI tool call hash
+ #
+ # @param tool_call [Hash] The tool call hash
+ # @return [Array] The tool call information
+ def extract_tool_call_args(tool_call:)
+ tool_call_id = tool_call.dig("id")
+
+ function_name = tool_call.dig("function", "name")
+ tool_name, method_name = function_name.split("__")
+
+ tool_arguments = tool_call.dig("function", "arguments")
+ tool_arguments = if tool_arguments.is_a?(Hash)
+ Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
+ else
+ JSON.parse(tool_arguments, symbolize_names: true)
+ end
+
+ [tool_call_id, tool_name, method_name, tool_arguments]
+ end
+ end
+
+ class OpenAI < Base
+ def build_chat_params(tools:, instructions:, messages:)
+ params = {messages: messages}
+ if tools.any?
+ params[:tools] = tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
+ params[:tool_choice] = "auto"
+ end
+ params
+ end
+
+ def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+ end
+
+ # Extract the tool call information from the OpenAI tool call hash
+ #
+ # @param tool_call [Hash] The tool call hash
+ # @return [Array] The tool call information
+ def extract_tool_call_args(tool_call:)
+ tool_call_id = tool_call.dig("id")
+
+ function_name = tool_call.dig("function", "name")
+ tool_name, method_name = function_name.split("__")
+
+ tool_arguments = tool_call.dig("function", "arguments")
+ tool_arguments = if tool_arguments.is_a?(Hash)
+ Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
+ else
+ JSON.parse(tool_arguments, symbolize_names: true)
+ end
+
+ [tool_call_id, tool_name, method_name, tool_arguments]
+ end
+ end
+
+ class GoogleGemini < Base
+ def build_chat_params(tools:, instructions:, messages:)
+ params = {messages: messages}
+ if tools.any?
+ params[:tools] = tools.map { |tool| tool.class.function_schemas.to_google_gemini_format }.flatten
+ params[:system] = instructions if instructions
+ params[:tool_choice] = "auto"
+ end
+ params
+ end
+
+ def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+ end
+
+ # Extract the tool call information from the Google Gemini tool call hash
+ #
+ # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
+ # @return [Array] The tool call information
+ def extract_tool_call_args(tool_call:)
+ tool_call_id = tool_call.dig("functionCall", "name")
+ function_name = tool_call.dig("functionCall", "name")
+ tool_name, method_name = function_name.split("__")
+ tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
+ [tool_call_id, tool_name, method_name, tool_arguments]
+ end
+ end
+
+ class Anthropic < Base
+ def build_chat_params(tools:, instructions:, messages:)
+ params = {messages: messages}
+ if tools.any?
+ params[:tools] = tools.map { |tool| tool.class.function_schemas.to_anthropic_format }.flatten
+ params[:tool_choice] = {type: "auto"}
+ end
+ params[:system] = instructions if instructions
+ params
+ end
+
+ def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
+ end
+
+ # Extract the tool call information from the Anthropic tool call hash
+ #
+ # @param tool_call [Hash] The tool call hash, format: {"type"=>"tool_use", "id"=>"toolu_01TjusbFApEbwKPRWTRwzadR", "name"=>"news_retriever__get_top_headlines", "input"=>{"country"=>"us", "page_size"=>10}}], "stop_reason"=>"tool_use"}
+ # @return [Array] The tool call information
+ def extract_tool_call_args(tool_call:)
+ tool_call_id = tool_call.dig("id")
+ function_name = tool_call.dig("name")
+ tool_name, method_name = function_name.split("__")
+ tool_arguments = tool_call.dig("input").transform_keys(&:to_sym)
+ [tool_call_id, tool_name, method_name, tool_arguments]
+ end
+ end
  end
  end
-
- # TODO: Fix the message truncation when context window is exceeded
  end
  end
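The net effect of this refactor: the provider-specific branching that previously lived in `chat_with_llm`, `run_tools`, and `build_message` collapses into one adapter object chosen at construction time, and token usage now accumulates across runs. A usage sketch under those assumptions (the Ollama URL and the `MyWeather` tool from the earlier sketch are illustrative):

```ruby
llm = Langchain::LLM::Ollama.new(url: "http://localhost:11434")

# LLM::Adapter.build(llm) selects Adapters::Ollama internally
assistant = Langchain::Assistant.new(llm: llm, tools: [MyWeather.new])

assistant.add_message_and_run content: "Weather in NYC?", auto_tool_execution: true

# Counters incremented by record_used_tokens after each LLM call
assistant.total_prompt_tokens
assistant.total_completion_tokens
assistant.total_tokens
```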
data/lib/langchain/assistants/messages/ollama_message.rb ADDED
@@ -0,0 +1,74 @@
+ # frozen_string_literal: true
+
+ module Langchain
+ module Messages
+ class OllamaMessage < Base
+ # OpenAI uses the following roles:
+ ROLES = [
+ "system",
+ "assistant",
+ "user",
+ "tool"
+ ].freeze
+
+ TOOL_ROLE = "tool"
+
+ # Initialize a new OpenAI message
+ #
+ # @param [String] The role of the message
+ # @param [String] The content of the message
+ # @param [Array<Hash>] The tool calls made in the message
+ # @param [String] The ID of the tool call
+ def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil)
+ raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
+ raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
+
+ @role = role
+ # Some Tools return content as a JSON hence `.to_s`
+ @content = content.to_s
+ @tool_calls = tool_calls
+ @tool_call_id = tool_call_id
+ end
+
+ # Convert the message to an OpenAI API-compatible hash
+ #
+ # @return [Hash] The message as an OpenAI API-compatible hash
+ def to_hash
+ {}.tap do |h|
+ h[:role] = role
+ h[:content] = content if content # Content is nil for tool calls
+ h[:tool_calls] = tool_calls if tool_calls.any?
+ h[:tool_call_id] = tool_call_id if tool_call_id
+ end
+ end
+
+ # Check if the message came from an LLM
+ #
+ # @return [Boolean] true/false whether this message was produced by an LLM
+ def llm?
+ assistant?
+ end
+
+ # Check if the message came from an LLM
+ #
+ # @return [Boolean] true/false whether this message was produced by an LLM
+ def assistant?
+ role == "assistant"
+ end
+
+ # Check if the message are system instructions
+ #
+ # @return [Boolean] true/false whether this message are system instructions
+ def system?
+ role == "system"
+ end
+
+ # Check if the message is a tool call
+ #
+ # @return [Boolean] true/false whether this message is a tool call
+ def tool?
+ role == "tool"
+ end
+ end
+ end
+ end
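Per the `to_hash` implementation above, empty fields are omitted, so a plain user message serializes to just `role` and `content` while tool-call messages carry `tool_calls` or `tool_call_id`. A quick illustration (values hypothetical):

```ruby
message = Langchain::Messages::OllamaMessage.new(role: "user", content: "Hi")
message.to_hash # => {role: "user", content: "Hi"}
message.tool?   # => false
```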
data/lib/langchain/assistants/thread.rb CHANGED
@@ -17,7 +17,14 @@ module Langchain
  #
  # @return [Array<Hash>] The thread as an OpenAI API-compatible array of hashes
  def array_of_message_hashes
- messages.map(&:to_hash)
+ messages
+ .map(&:to_hash)
+ .compact
+ end
+
+ # Only used by the Assistant when it calls the LLM#complete() method
+ def prompt_of_concatenated_messages
+ messages.map(&:to_s).join
  end
 
  # Add a message to the thread
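`prompt_of_concatenated_messages` supports the completion-only path the Assistant now handles, joining each message's `to_s` into one prompt string, while the added `.compact` drops any message whose `to_hash` returns nil. A sketch, assuming each message's `to_s` returns its content:

```ruby
thread.prompt_of_concatenated_messages
# e.g. "You are a helpful assistant.What is 2 + 2?" (one concatenated string)
```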
data/lib/langchain/contextual_logger.rb CHANGED
@@ -35,8 +35,8 @@ module Langchain
  @logger.respond_to?(method, include_private)
  end
 
- def method_missing(method, *args, **kwargs, &)
- return @logger.send(method, *args, **kwargs, &) unless @levels.include?(method)
+ def method_missing(method, *args, **kwargs, &block)
+ return @logger.send(method, *args, **kwargs, &block) unless @levels.include?(method)
 
  for_class = kwargs.delete(:for)
  for_class_name = for_class&.name
data/lib/langchain/llm/ai21.rb CHANGED
@@ -16,8 +16,6 @@ module Langchain::LLM
  model: "j2-ultra"
  }.freeze
 
- LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AI21Validator
-
  def initialize(api_key:, default_options: {})
  depends_on "ai21"
 
@@ -35,8 +33,6 @@ module Langchain::LLM
  def complete(prompt:, **params)
  parameters = complete_parameters params
 
- parameters[:maxTokens] = LENGTH_VALIDATOR.validate_max_tokens!(prompt, parameters[:model], {llm: client})
-
  response = client.complete(prompt, parameters)
  Langchain::LLM::AI21Response.new response, model: parameters[:model]
  end
data/lib/langchain/llm/anthropic.rb CHANGED
@@ -5,10 +5,10 @@ module Langchain::LLM
  # Wrapper around Anthropic APIs.
  #
  # Gem requirements:
- # gem "anthropic", "~> 0.1.0"
+ # gem "anthropic", "~> 0.3.0"
  #
  # Usage:
- # anthorpic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
+ # anthropic = Langchain::LLM::Anthropic.new(api_key: ENV["ANTHROPIC_API_KEY"])
  #
  class Anthropic < Base
  DEFAULTS = {
@@ -18,9 +18,6 @@ module Langchain::LLM
  max_tokens_to_sample: 256
  }.freeze
 
- # TODO: Implement token length validator for Anthropic
- # LENGTH_VALIDATOR = Langchain::Utils::TokenLength::AnthropicValidator
-
  # Initialize an Anthropic LLM instance
  #
  # @param api_key [String] The API key to use
@@ -81,7 +78,10 @@ module Langchain::LLM
  parameters[:metadata] = metadata if metadata
  parameters[:stream] = stream if stream
 
- response = client.complete(parameters: parameters)
+ response = with_api_error_handling do
+ client.complete(parameters: parameters)
+ end
+
  Langchain::LLM::AnthropicResponse.new(response)
  end
 
@@ -114,6 +114,15 @@ module Langchain::LLM
  Langchain::LLM::AnthropicResponse.new(response)
  end
 
+ def with_api_error_handling
+ response = yield
+ return if response.empty?
+
+ raise Langchain::LLM::ApiError.new "Anthropic API error: #{response.dig("error", "message")}" if response&.dig("error")
+
+ response
+ end
+
  private
 
  def set_extra_headers!
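The new `with_api_error_handling` wrapper raises `Langchain::LLM::ApiError` when the Anthropic client returns an error payload, rather than silently wrapping it in an `AnthropicResponse`. A sketch of the behavior with an assumed error payload:

```ruby
response = {"error" => {"type" => "authentication_error", "message" => "invalid x-api-key"}}
# Inside #complete, with_api_error_handling { response } would raise:
#   Langchain::LLM::ApiError: Anthropic API error: invalid x-api-key
```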
data/lib/langchain/llm/azure.rb CHANGED
@@ -42,17 +42,17 @@ module Langchain::LLM
 
  def embed(...)
  @client = @embed_client
- super(...)
+ super
  end
 
  def complete(...)
  @client = @chat_client
- super(...)
+ super
  end
 
  def chat(...)
  @client = @chat_client
- super(...)
+ super
  end
  end
  end