langchainrb 0.16.1 → 0.17.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b078089a99e9e8d6654a244165ecc9d0f3dfdd8fbc0367623d41fe771a98ac41
4
- data.tar.gz: 890c371564ce9188087bed9eb053a59e11f7b734a44b9f753696f8458f8a7b7e
3
+ metadata.gz: f7061ef2090d35626239ca575b60edb291dbbadab7de85a5a2796792e1691437
4
+ data.tar.gz: 30cb1f14b602a22e7df8f2dba42660383d44482cbe83fb35dc9539afa836739c
5
5
  SHA512:
6
- metadata.gz: 8f458bfae5af31190f41661a13c24e5cd63d5f88e594e854ee79ea3b8af1f51b20552f178c89c71e454a86e4d827b3facecd97f0fa3ef107b7b6097754fab5e3
7
- data.tar.gz: cfe0c684f89c5eef73ceb26b70292fe8fc4f941e13795ac98bd1d3197321a1303250d2584f9e45aa530e311304004911ffe3a6af7f606f6a733baad21ff2b814
6
+ metadata.gz: dd08fb29bd0ff9237cc27980c3bac607baeb9d54a93f297b1e81fb863b7cbb9720db4adacb3dae92bcbe71d2eb59b38d4de0ee321face467a0b82bde627d2929
7
+ data.tar.gz: 4a3661afd2d9d75a64e02f6f173cd0bf0e016207c444ca4506bab907f00dc906f5bbf82c96aad9521280265f214b0f3e82dd5e4ee54dc40f3afb415a6f50b365
data/CHANGELOG.md CHANGED
@@ -1,5 +1,16 @@
1
1
  ## [Unreleased]
2
2
 
3
+ ## [0.17.1] - 2024-10-07
4
+ - Move Langchain::Assistant::LLM::Adapter-related classes to separate files
5
+ - Fix Langchain::Tool::Database#describe_table method
6
+
7
+ ## [0.17.0] - 2024-10-02
8
+ - [BREAKING] Langchain::Vectorsearch::Milvus was rewritten to work with newer milvus 0.10.0 gem
9
+ - [BREAKING] Removing Langchain::LLM::GooglePalm
10
+ - Assistant can now process image_urls in the messages (currently only for OpenAI and Mistral AI)
11
+ - Vectorsearch providers utilize the global Langchain.logger
12
+ - Update required milvus, qdrant and weaviate versions
13
+
3
14
  ## [0.16.1] - 2024-09-30
4
15
  - Deprecate Langchain::LLM::GooglePalm
5
16
  - Allow setting response_object: {} parameter when initializing supported Langchain::LLM::* classes
data/README.md CHANGED
@@ -63,7 +63,6 @@ The `Langchain::LLM` module provides a unified interface for interacting with va
63
63
  - Azure OpenAI
64
64
  - Cohere
65
65
  - Google Gemini
66
- - Google PaLM (deprecated)
67
66
  - Google Vertex AI
68
67
  - HuggingFace
69
68
  - LlamaCpp
@@ -501,6 +500,12 @@ assistant = Langchain::Assistant.new(
501
500
  # Add a user message and run the assistant
502
501
  assistant.add_message_and_run!(content: "What's the latest news about AI?")
503
502
 
503
+ # Supply an image URL to the assistant (currently supported for OpenAI and Mistral AI only)
504
+ assistant.add_message_and_run!(
505
+ content: "Describe what you see in this image",
506
+ image_url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
507
+ )
508
+
504
509
  # Access the conversation thread
505
510
  messages = assistant.messages
506
511
 
@@ -1,5 +1,7 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require_relative "llm/adapter"
4
+
3
5
  module Langchain
4
6
  # Assistants are Agent-like objects that leverage helpful instructions, LLMs, tools and knowledge to respond to user queries.
5
7
  # Assistants can be configured with an LLM of your choice, any vector search database and easily extended with additional tools.
@@ -63,13 +65,14 @@ module Langchain
63
65
 
64
66
  # Add a user message to the messages array
65
67
  #
66
- # @param content [String] The content of the message
67
68
  # @param role [String] The role attribute of the message. Default: "user"
69
+ # @param content [String] The content of the message
70
+ # @param image_url [String] The URL of the image to include in the message
68
71
  # @param tool_calls [Array<Hash>] The tool calls to include in the message
69
72
  # @param tool_call_id [String] The ID of the tool call to include in the message
70
73
  # @return [Array<Langchain::Message>] The messages
71
- def add_message(content: nil, role: "user", tool_calls: [], tool_call_id: nil)
72
- message = build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
74
+ def add_message(role: "user", content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
75
+ message = build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
73
76
 
74
77
  # Call the callback with the message
75
78
  add_message_callback.call(message) if add_message_callback # rubocop:disable Style/SafeNavigation
@@ -145,8 +148,8 @@ module Langchain
145
148
  # @param content [String] The content of the message (optional; defaults to nil)
  # @param image_url [String] The URL of the image to include in the message (optional; defaults to nil)
146
149
  # @param auto_tool_execution [Boolean] Whether or not to automatically run tools
147
150
  # @return [Array<Langchain::Message>] The messages
148
- def add_message_and_run(content:, auto_tool_execution: false)
149
- add_message(content: content, role: "user")
151
+ def add_message_and_run(content: nil, image_url: nil, auto_tool_execution: false)
152
+ add_message(content: content, image_url: image_url, role: "user")
150
153
  run(auto_tool_execution: auto_tool_execution)
151
154
  end
152
155
 
@@ -154,8 +157,8 @@ module Langchain
154
157
  #
155
158
  # @param content [String] The content of the message (optional; defaults to nil)
  # @param image_url [String] The URL of the image to include in the message (optional; defaults to nil)
  # Equivalent to add_message_and_run(..., auto_tool_execution: true).
156
159
  # @return [Array<Langchain::Message>] The messages
157
- def add_message_and_run!(content:)
158
- add_message_and_run(content: content, auto_tool_execution: true)
160
+ def add_message_and_run!(content: nil, image_url: nil)
161
+ add_message_and_run(content: content, image_url: image_url, auto_tool_execution: true)
159
162
  end
160
163
 
161
164
  # Submit tool output
@@ -388,11 +391,12 @@ module Langchain
388
391
  #
389
392
  # @param role [String] The role of the message
390
393
  # @param content [String] The content of the message
394
+ # @param image_url [String] The URL of the image to include in the message
391
395
  # @param tool_calls [Array<Hash>] The tool calls to include in the message
392
396
  # @param tool_call_id [String] The ID of the tool call to include in the message
393
397
  # @return [Langchain::Message] The Message object
394
# NOTE: image_url is passed to @llm_adapter unchanged. In this release the Anthropic and
# Google Gemini adapters print a warning (Kernel#warn) and drop it, while the Mistral AI
# adapter forwards it to its message class.
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
395
- @llm_adapter.build_message(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
398
+ def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
399
+ @llm_adapter.build_message(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
396
400
  end
397
401
 
398
402
  # Increment the tokens count based on the last interaction with the LLM
@@ -410,314 +414,5 @@ module Langchain
410
414
  def available_tool_names
411
415
  llm_adapter.available_tool_names(tools)
412
416
  end
413
-
414
- # TODO: Fix the message truncation when context window is exceeded
415
-
416
- module LLM
417
- class Adapter
418
- def self.build(llm)
419
- case llm
420
- when Langchain::LLM::Anthropic
421
- Adapters::Anthropic.new
422
- when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
423
- Adapters::GoogleGemini.new
424
- when Langchain::LLM::MistralAI
425
- Adapters::MistralAI.new
426
- when Langchain::LLM::Ollama
427
- Adapters::Ollama.new
428
- when Langchain::LLM::OpenAI
429
- Adapters::OpenAI.new
430
- else
431
- raise ArgumentError, "Unsupported LLM type: #{llm.class}"
432
- end
433
- end
434
- end
435
-
436
- module Adapters
437
- class Base
438
- def build_chat_params(tools:, instructions:, messages:, tool_choice:)
439
- raise NotImplementedError, "Subclasses must implement build_chat_params"
440
- end
441
-
442
- def extract_tool_call_args(tool_call:)
443
- raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
444
- end
445
-
446
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
447
- raise NotImplementedError, "Subclasses must implement build_message"
448
- end
449
- end
450
-
451
- class Ollama < Base
452
- def build_chat_params(tools:, instructions:, messages:, tool_choice:)
453
- params = {messages: messages}
454
- if tools.any?
455
- params[:tools] = build_tools(tools)
456
- end
457
- params
458
- end
459
-
460
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
461
- Langchain::Messages::OllamaMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
462
- end
463
-
464
- # Extract the tool call information from the OpenAI tool call hash
465
- #
466
- # @param tool_call [Hash] The tool call hash
467
- # @return [Array] The tool call information
468
- def extract_tool_call_args(tool_call:)
469
- tool_call_id = tool_call.dig("id")
470
-
471
- function_name = tool_call.dig("function", "name")
472
- tool_name, method_name = function_name.split("__")
473
-
474
- tool_arguments = tool_call.dig("function", "arguments")
475
- tool_arguments = if tool_arguments.is_a?(Hash)
476
- Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
477
- else
478
- JSON.parse(tool_arguments, symbolize_names: true)
479
- end
480
-
481
- [tool_call_id, tool_name, method_name, tool_arguments]
482
- end
483
-
484
- def available_tool_names(tools)
485
- build_tools(tools).map { |tool| tool.dig(:function, :name) }
486
- end
487
-
488
- def allowed_tool_choices
489
- ["auto", "none"]
490
- end
491
-
492
- private
493
-
494
- def build_tools(tools)
495
- tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
496
- end
497
- end
498
-
499
- class OpenAI < Base
500
- def build_chat_params(tools:, instructions:, messages:, tool_choice:)
501
- params = {messages: messages}
502
- if tools.any?
503
- params[:tools] = build_tools(tools)
504
- params[:tool_choice] = build_tool_choice(tool_choice)
505
- end
506
- params
507
- end
508
-
509
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
510
- Langchain::Messages::OpenAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
511
- end
512
-
513
- # Extract the tool call information from the OpenAI tool call hash
514
- #
515
- # @param tool_call [Hash] The tool call hash
516
- # @return [Array] The tool call information
517
- def extract_tool_call_args(tool_call:)
518
- tool_call_id = tool_call.dig("id")
519
-
520
- function_name = tool_call.dig("function", "name")
521
- tool_name, method_name = function_name.split("__")
522
-
523
- tool_arguments = tool_call.dig("function", "arguments")
524
- tool_arguments = if tool_arguments.is_a?(Hash)
525
- Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
526
- else
527
- JSON.parse(tool_arguments, symbolize_names: true)
528
- end
529
-
530
- [tool_call_id, tool_name, method_name, tool_arguments]
531
- end
532
-
533
- def build_tools(tools)
534
- tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
535
- end
536
-
537
- def allowed_tool_choices
538
- ["auto", "none"]
539
- end
540
-
541
- def available_tool_names(tools)
542
- build_tools(tools).map { |tool| tool.dig(:function, :name) }
543
- end
544
-
545
- private
546
-
547
- def build_tool_choice(choice)
548
- case choice
549
- when "auto"
550
- choice
551
- else
552
- {"type" => "function", "function" => {"name" => choice}}
553
- end
554
- end
555
- end
556
-
557
- class MistralAI < Base
558
- def build_chat_params(tools:, instructions:, messages:, tool_choice:)
559
- params = {messages: messages}
560
- if tools.any?
561
- params[:tools] = build_tools(tools)
562
- params[:tool_choice] = build_tool_choice(tool_choice)
563
- end
564
- params
565
- end
566
-
567
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
568
- Langchain::Messages::MistralAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
569
- end
570
-
571
- # Extract the tool call information from the OpenAI tool call hash
572
- #
573
- # @param tool_call [Hash] The tool call hash
574
- # @return [Array] The tool call information
575
- def extract_tool_call_args(tool_call:)
576
- tool_call_id = tool_call.dig("id")
577
-
578
- function_name = tool_call.dig("function", "name")
579
- tool_name, method_name = function_name.split("__")
580
-
581
- tool_arguments = tool_call.dig("function", "arguments")
582
- tool_arguments = if tool_arguments.is_a?(Hash)
583
- Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
584
- else
585
- JSON.parse(tool_arguments, symbolize_names: true)
586
- end
587
-
588
- [tool_call_id, tool_name, method_name, tool_arguments]
589
- end
590
-
591
- def build_tools(tools)
592
- tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
593
- end
594
-
595
- def allowed_tool_choices
596
- ["auto", "none"]
597
- end
598
-
599
- def available_tool_names(tools)
600
- build_tools(tools).map { |tool| tool.dig(:function, :name) }
601
- end
602
-
603
- private
604
-
605
- def build_tool_choice(choice)
606
- case choice
607
- when "auto"
608
- choice
609
- else
610
- {"type" => "function", "function" => {"name" => choice}}
611
- end
612
- end
613
- end
614
-
615
- class GoogleGemini < Base
616
- def build_chat_params(tools:, instructions:, messages:, tool_choice:)
617
- params = {messages: messages}
618
- if tools.any?
619
- params[:tools] = build_tools(tools)
620
- params[:system] = instructions if instructions
621
- params[:tool_choice] = build_tool_config(tool_choice)
622
- end
623
- params
624
- end
625
-
626
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
627
- Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
628
- end
629
-
630
- # Extract the tool call information from the Google Gemini tool call hash
631
- #
632
- # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
633
- # @return [Array] The tool call information
634
- def extract_tool_call_args(tool_call:)
635
- tool_call_id = tool_call.dig("functionCall", "name")
636
- function_name = tool_call.dig("functionCall", "name")
637
- tool_name, method_name = function_name.split("__")
638
- tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
639
- [tool_call_id, tool_name, method_name, tool_arguments]
640
- end
641
-
642
- def build_tools(tools)
643
- tools.map { |tool| tool.class.function_schemas.to_google_gemini_format }.flatten
644
- end
645
-
646
- def allowed_tool_choices
647
- ["auto", "none"]
648
- end
649
-
650
- def available_tool_names(tools)
651
- build_tools(tools).map { |tool| tool.dig(:name) }
652
- end
653
-
654
- private
655
-
656
- def build_tool_config(choice)
657
- case choice
658
- when "auto"
659
- {function_calling_config: {mode: "auto"}}
660
- when "none"
661
- {function_calling_config: {mode: "none"}}
662
- else
663
- {function_calling_config: {mode: "any", allowed_function_names: [choice]}}
664
- end
665
- end
666
- end
667
-
668
- class Anthropic < Base
669
- def build_chat_params(tools:, instructions:, messages:, tool_choice:)
670
- params = {messages: messages}
671
- if tools.any?
672
- params[:tools] = build_tools(tools)
673
- params[:tool_choice] = build_tool_choice(tool_choice)
674
- end
675
- params[:system] = instructions if instructions
676
- params
677
- end
678
-
679
- def build_message(role:, content: nil, tool_calls: [], tool_call_id: nil)
680
- Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
681
- end
682
-
683
- # Extract the tool call information from the Anthropic tool call hash
684
- #
685
- # @param tool_call [Hash] The tool call hash, format: {"type"=>"tool_use", "id"=>"toolu_01TjusbFApEbwKPRWTRwzadR", "name"=>"news_retriever__get_top_headlines", "input"=>{"country"=>"us", "page_size"=>10}}], "stop_reason"=>"tool_use"}
686
- # @return [Array] The tool call information
687
- def extract_tool_call_args(tool_call:)
688
- tool_call_id = tool_call.dig("id")
689
- function_name = tool_call.dig("name")
690
- tool_name, method_name = function_name.split("__")
691
- tool_arguments = tool_call.dig("input").transform_keys(&:to_sym)
692
- [tool_call_id, tool_name, method_name, tool_arguments]
693
- end
694
-
695
- def build_tools(tools)
696
- tools.map { |tool| tool.class.function_schemas.to_anthropic_format }.flatten
697
- end
698
-
699
- def allowed_tool_choices
700
- ["auto", "any"]
701
- end
702
-
703
- def available_tool_names(tools)
704
- build_tools(tools).map { |tool| tool.dig(:name) }
705
- end
706
-
707
- private
708
-
709
- def build_tool_choice(choice)
710
- case choice
711
- when "auto"
712
- {type: "auto"}
713
- when "any"
714
- {type: "any"}
715
- else
716
- {type: "tool", name: choice}
717
- end
718
- end
719
- end
720
- end
721
- end
722
417
  end
723
418
  end
@@ -0,0 +1,27 @@
1
# Loads every adapter file under ./adapters in sorted (alphabetical) order.
# NOTE(review): "anthropic.rb" sorts before "base.rb", yet Anthropic subclasses Base, so this
# only works if Base is autoloaded on first reference (e.g. by a Zeitwerk loader) and if
# Pathname is already loaded; otherwise require "adapters/base" first — confirm.
+ Dir[Pathname.new(__FILE__).dirname.join("adapters", "*.rb")].sort.each { |file| require file }
2
+
3
+ module Langchain
4
+ class Assistant
5
+ module LLM
6
+ # TODO: Fix the message truncation when context window is exceeded
7
# Factory mapping a Langchain::LLM::* instance to its Assistant adapter. GoogleVertexAI
# shares the GoogleGemini adapter; any unlisted LLM class raises ArgumentError.
+ class Adapter
8
+ def self.build(llm)
9
+ case llm
10
+ when Langchain::LLM::Anthropic
11
+ LLM::Adapters::Anthropic.new
12
+ when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
13
+ LLM::Adapters::GoogleGemini.new
14
+ when Langchain::LLM::MistralAI
15
+ LLM::Adapters::MistralAI.new
16
+ when Langchain::LLM::Ollama
17
+ LLM::Adapters::Ollama.new
18
+ when Langchain::LLM::OpenAI
19
+ LLM::Adapters::OpenAI.new
20
+ else
21
+ raise ArgumentError, "Unsupported LLM type: #{llm.class}"
22
+ end
23
+ end
24
+ end
25
+ end
26
+ end
27
+ end
@@ -0,0 +1,21 @@
1
+ module Langchain
2
+ class Assistant
3
+ module LLM
4
+ module Adapters
5
# Abstract adapter contract: every method below only raises NotImplementedError and is
# overridden by a provider adapter. The concrete adapters in this release also implement
# available_tool_names(tools) and allowed_tool_choices, which Base does not declare;
# Assistant#available_tool_names calls the former on its adapter.
+ class Base
6
# Returns the keyword params hash passed to the provider's chat call.
+ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
7
+ raise NotImplementedError, "Subclasses must implement build_chat_params"
8
+ end
9
+
10
# Returns [tool_call_id, tool_name, method_name, tool_arguments] for one provider tool call.
+ def extract_tool_call_args(tool_call:)
11
+ raise NotImplementedError, "Subclasses must implement extract_tool_call_args"
12
+ end
13
+
14
# Returns a provider-specific Langchain::Messages::* object.
+ def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
15
+ raise NotImplementedError, "Subclasses must implement build_message"
16
+ end
17
+ end
18
+ end
19
+ end
20
+ end
21
+ end
@@ -0,0 +1,62 @@
1
+ module Langchain
2
+ class Assistant
3
+ module LLM
4
+ module Adapters
5
+ class Anthropic < Base
6
# Builds chat params; tools/tool_choice are only sent when tools are configured, but
# system instructions are sent whenever present.
+ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
7
+ params = {messages: messages}
8
+ if tools.any?
9
+ params[:tools] = build_tools(tools)
10
+ params[:tool_choice] = build_tool_choice(tool_choice)
11
+ end
12
+ params[:system] = instructions if instructions
13
+ params
14
+ end
15
+
16
# image_url is accepted for interface parity but dropped after printing a warning.
+ def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
17
+ warn "Image URL is not supported by Anthropic currently" if image_url
18
+
19
+ Langchain::Messages::AnthropicMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
20
+ end
21
+
22
+ # Extract the tool call information from the Anthropic tool call hash
23
+ #
24
+ # @param tool_call [Hash] The tool call hash, format: {"type"=>"tool_use", "id"=>"toolu_01TjusbFApEbwKPRWTRwzadR", "name"=>"news_retriever__get_top_headlines", "input"=>{"country"=>"us", "page_size"=>10}}
25
+ # @return [Array] The tool call information: [tool_call_id, tool_name, method_name, tool_arguments]
26
+ def extract_tool_call_args(tool_call:)
27
+ tool_call_id = tool_call.dig("id")
28
+ function_name = tool_call.dig("name")
29
+ tool_name, method_name = function_name.split("__")
30
+ tool_arguments = tool_call.dig("input").transform_keys(&:to_sym)
31
+ [tool_call_id, tool_name, method_name, tool_arguments]
32
+ end
33
+
34
+ def build_tools(tools)
35
+ tools.map { |tool| tool.class.function_schemas.to_anthropic_format }.flatten
36
+ end
37
+
38
+ def allowed_tool_choices
39
+ ["auto", "any"]
40
+ end
41
+
42
+ def available_tool_names(tools)
43
+ build_tools(tools).map { |tool| tool.dig(:name) }
44
+ end
45
+
46
+ private
47
+
48
# "auto"/"any" map to Anthropic tool_choice types; any other value is treated as a tool name.
+ def build_tool_choice(choice)
49
+ case choice
50
+ when "auto"
51
+ {type: "auto"}
52
+ when "any"
53
+ {type: "any"}
54
+ else
55
+ {type: "tool", name: choice}
56
+ end
57
+ end
58
+ end
59
+ end
60
+ end
61
+ end
62
+ end
@@ -0,0 +1,62 @@
1
+ module Langchain
2
+ class Assistant
3
+ module LLM
4
+ module Adapters
5
+ class GoogleGemini < Base
6
+ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
7
+ params = {messages: messages}
8
+ if tools.any?
9
+ params[:tools] = build_tools(tools)
10
+ params[:system] = instructions if instructions
11
+ params[:tool_choice] = build_tool_config(tool_choice)
12
+ end
13
+ params
14
+ end
15
+
16
+ def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
17
+ warn "Image URL is not supported by Google Gemini" if image_url
18
+
19
+ Langchain::Messages::GoogleGeminiMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
20
+ end
21
+
22
+ # Extract the tool call information from the Google Gemini tool call hash
23
+ #
24
+ # @param tool_call [Hash] The tool call hash, format: {"functionCall"=>{"name"=>"weather__execute", "args"=>{"input"=>"NYC"}}}
25
+ # @return [Array] The tool call information
26
+ def extract_tool_call_args(tool_call:)
27
+ tool_call_id = tool_call.dig("functionCall", "name")
28
+ function_name = tool_call.dig("functionCall", "name")
29
+ tool_name, method_name = function_name.split("__")
30
+ tool_arguments = tool_call.dig("functionCall", "args").transform_keys(&:to_sym)
31
+ [tool_call_id, tool_name, method_name, tool_arguments]
32
+ end
33
+
34
+ def build_tools(tools)
35
+ tools.map { |tool| tool.class.function_schemas.to_google_gemini_format }.flatten
36
+ end
37
+
38
+ def allowed_tool_choices
39
+ ["auto", "none"]
40
+ end
41
+
42
+ def available_tool_names(tools)
43
+ build_tools(tools).map { |tool| tool.dig(:name) }
44
+ end
45
+
46
+ private
47
+
48
+ def build_tool_config(choice)
49
+ case choice
50
+ when "auto"
51
+ {function_calling_config: {mode: "auto"}}
52
+ when "none"
53
+ {function_calling_config: {mode: "none"}}
54
+ else
55
+ {function_calling_config: {mode: "any", allowed_function_names: [choice]}}
56
+ end
57
+ end
58
+ end
59
+ end
60
+ end
61
+ end
62
+ end
@@ -0,0 +1,65 @@
1
+ module Langchain
2
+ class Assistant
3
+ module LLM
4
+ module Adapters
5
+ class MistralAI < Base
6
+ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
7
+ params = {messages: messages}
8
+ if tools.any?
9
+ params[:tools] = build_tools(tools)
10
+ params[:tool_choice] = build_tool_choice(tool_choice)
11
+ end
12
+ params
13
+ end
14
+
15
+ def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
16
+ Langchain::Messages::MistralAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
17
+ end
18
+
19
+ # Extract the tool call information from the OpenAI tool call hash
20
+ #
21
+ # @param tool_call [Hash] The tool call hash
22
+ # @return [Array] The tool call information
23
+ def extract_tool_call_args(tool_call:)
24
+ tool_call_id = tool_call.dig("id")
25
+
26
+ function_name = tool_call.dig("function", "name")
27
+ tool_name, method_name = function_name.split("__")
28
+
29
+ tool_arguments = tool_call.dig("function", "arguments")
30
+ tool_arguments = if tool_arguments.is_a?(Hash)
31
+ Langchain::Utils::HashTransformer.symbolize_keys(tool_arguments)
32
+ else
33
+ JSON.parse(tool_arguments, symbolize_names: true)
34
+ end
35
+
36
+ [tool_call_id, tool_name, method_name, tool_arguments]
37
+ end
38
+
39
+ def build_tools(tools)
40
+ tools.map { |tool| tool.class.function_schemas.to_openai_format }.flatten
41
+ end
42
+
43
+ def allowed_tool_choices
44
+ ["auto", "none"]
45
+ end
46
+
47
+ def available_tool_names(tools)
48
+ build_tools(tools).map { |tool| tool.dig(:function, :name) }
49
+ end
50
+
51
+ private
52
+
53
+ def build_tool_choice(choice)
54
+ case choice
55
+ when "auto"
56
+ choice
57
+ else
58
+ {"type" => "function", "function" => {"name" => choice}}
59
+ end
60
+ end
61
+ end
62
+ end
63
+ end
64
+ end
65
+ end