red-candle 1.8.0-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,104 @@
1
+ use candle_core::Result as CandleResult;
2
+ use tokenizers::Tokenizer;
3
+
4
+ pub mod loader;
5
+
6
+ /// Common structure for managing tokenizer
7
+ #[derive(Debug, Clone)]
8
+ pub struct TokenizerWrapper {
9
+ tokenizer: Tokenizer,
10
+ }
11
+
12
+ impl TokenizerWrapper {
13
+ pub fn new(tokenizer: Tokenizer) -> Self {
14
+ Self { tokenizer }
15
+ }
16
+
17
+ pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
18
+ let encoding = self.tokenizer
19
+ .encode(text, add_special_tokens)
20
+ .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
21
+ Ok(encoding.get_ids().to_vec())
22
+ }
23
+
24
+ pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
25
+ self.tokenizer
26
+ .decode(tokens, skip_special_tokens)
27
+ .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
28
+ }
29
+
30
+ pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
31
+ self.tokenizer
32
+ .id_to_token(token)
33
+ .map(|s| s.to_string())
34
+ .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
35
+ }
36
+
37
+ /// Decode a single token for streaming output
38
+ pub fn decode_token(&self, token: u32) -> CandleResult<String> {
39
+ // Decode the single token properly
40
+ self.decode(&[token], true)
41
+ }
42
+
43
+ /// Decode tokens incrementally for streaming
44
+ /// This is more efficient than decoding single tokens
45
+ pub fn decode_incremental(&self, all_tokens: &[u32], new_tokens_start: usize) -> CandleResult<String> {
46
+ if new_tokens_start >= all_tokens.len() {
47
+ return Ok(String::new());
48
+ }
49
+
50
+ // Decode all tokens up to this point
51
+ let full_text = self.decode(all_tokens, true)?;
52
+
53
+ // If we're at the start, return everything
54
+ if new_tokens_start == 0 {
55
+ return Ok(full_text);
56
+ }
57
+
58
+ // Otherwise, decode up to the previous token and return the difference
59
+ let previous_text = self.decode(&all_tokens[..new_tokens_start], true)?;
60
+
61
+ // Find the common prefix between the two strings to handle cases where
62
+ // the tokenizer might produce slightly different text when decoding
63
+ // different token sequences
64
+ let common_len = full_text
65
+ .chars()
66
+ .zip(previous_text.chars())
67
+ .take_while(|(a, b)| a == b)
68
+ .count();
69
+
70
+ Ok(full_text.chars().skip(common_len).collect())
71
+ }
72
+
73
+ /// Format tokens with debug information
74
+ pub fn format_tokens_with_debug(&self, tokens: &[u32]) -> CandleResult<String> {
75
+ let mut result = String::new();
76
+ for &token in tokens {
77
+ let piece = self.token_to_piece(token)?;
78
+ result.push_str(&format!("[{}:{}]", token, piece));
79
+ }
80
+ Ok(result)
81
+ }
82
+
83
+ /// Encode a batch of texts (needed for reranker)
84
+ pub fn encode_batch(&self, texts: Vec<String>, add_special_tokens: bool) -> CandleResult<Vec<Vec<u32>>> {
85
+ let encodings = self.tokenizer
86
+ .encode_batch(texts, add_special_tokens)
87
+ .map_err(|e| candle_core::Error::Msg(format!("Tokenizer batch error: {}", e)))?;
88
+
89
+ Ok(encodings.into_iter()
90
+ .map(|encoding| encoding.get_ids().to_vec())
91
+ .collect())
92
+ }
93
+
94
+ /// Get the underlying tokenizer (for advanced use cases)
95
+ pub fn inner(&self) -> &Tokenizer {
96
+ &self.tokenizer
97
+ }
98
+
99
+ /// Get a mutable reference to the underlying tokenizer (for configuration)
100
+ pub fn inner_mut(&mut self) -> &mut Tokenizer {
101
+ &mut self.tokenizer
102
+ }
103
+ }
104
+
@@ -0,0 +1,43 @@
1
+ use candle_core::Device as CoreDevice;
2
+
3
#[test]
fn test_device_creation() {
    // The CPU variant is a plain enum value, so constructing it cannot fail.
    let device = CoreDevice::Cpu;
    assert!(matches!(device, CoreDevice::Cpu));

    // Its Debug rendering is the bare variant name.
    let rendered = format!("{:?}", device);
    assert_eq!(rendered, "Cpu");
}
12
+
13
#[cfg(feature = "cuda")]
#[test]
#[ignore = "requires CUDA hardware"]
fn test_cuda_device_creation() {
    // Device creation can fail on a machine without a usable CUDA device; that case is
    // reported rather than treated as a failure.
    if let Ok(device) = CoreDevice::new_cuda(0) {
        assert!(matches!(device, CoreDevice::Cuda(_)));
    } else {
        println!("No CUDA device available for testing");
    }
}
23
+
24
#[cfg(feature = "metal")]
#[test]
#[ignore = "requires Metal hardware"]
fn test_metal_device_creation() {
    // Device creation can fail on a machine without a usable Metal device; that case is
    // reported rather than treated as a failure.
    if let Ok(device) = CoreDevice::new_metal(0) {
        assert!(matches!(device, CoreDevice::Metal(_)));
    } else {
        println!("No Metal device available for testing");
    }
}
34
+
35
#[test]
fn test_device_matching() {
    let cpu1 = CoreDevice::Cpu;
    let cpu2 = CoreDevice::Cpu;

    // Each handle is the CPU variant...
    assert!(matches!(cpu1, CoreDevice::Cpu));
    assert!(matches!(cpu2, CoreDevice::Cpu));

    // ...and, as the test name promises, two CPU handles compare as the same device.
    assert!(cpu1.same_device(&cpu2));
    assert!(cpu2.same_device(&cpu1));
}
@@ -0,0 +1,162 @@
1
+ use candle_core::{Tensor, Device, DType};
2
+
3
#[test]
fn test_tensor_creation() {
    let device = Device::Cpu;

    // From a slice: shape and dtype come from the data, and the values round-trip.
    let data = [1.0f32, 2.0, 3.0, 4.0];
    let tensor = Tensor::new(&data[..], &device).unwrap();
    assert_eq!(tensor.dims(), &[4]);
    assert_eq!(tensor.dtype(), DType::F32);
    assert_eq!(tensor.to_vec1::<f32>().unwrap(), data.to_vec());

    // Zeros: requested shape, every element 0.
    let zeros = Tensor::zeros(&[2, 3], DType::F32, &device).unwrap();
    assert_eq!(zeros.dims(), &[2, 3]);
    assert_eq!(zeros.to_vec2::<f32>().unwrap(), vec![vec![0.0f32; 3]; 2]);

    // Ones: requested shape, every element 1.
    let ones = Tensor::ones(&[3, 2], DType::F32, &device).unwrap();
    assert_eq!(ones.dims(), &[3, 2]);
    assert_eq!(ones.to_vec2::<f32>().unwrap(), vec![vec![1.0f32; 2]; 3]);
}
21
+
22
#[test]
fn test_tensor_arithmetic() {
    let device = Device::Cpu;
    let lhs = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
    let rhs = Tensor::new(&[4.0f32, 5.0, 6.0], &device).unwrap();

    // Copies a rank-1 result back into a Vec for comparison.
    let host = |t: Tensor| -> Vec<f32> { t.to_vec1().unwrap() };

    // Element-wise addition, subtraction and multiplication.
    assert_eq!(host(lhs.add(&rhs).unwrap()), vec![5.0, 7.0, 9.0]);
    assert_eq!(host(lhs.sub(&rhs).unwrap()), vec![-3.0, -3.0, -3.0]);
    assert_eq!(host(lhs.mul(&rhs).unwrap()), vec![4.0, 10.0, 18.0]);
}
44
+
45
#[test]
fn test_tensor_reshape() {
    let device = Device::Cpu;
    let flat = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device).unwrap();

    // Six elements can be viewed as either 2x3 or 3x2.
    for target in &[[2usize, 3], [3, 2]] {
        let reshaped = flat.reshape(target).unwrap();
        assert_eq!(reshaped.dims(), target);
    }
}
59
+
60
#[test]
fn test_tensor_transpose() {
    let device = Device::Cpu;
    // [[1, 2],
    //  [3, 4]]
    let square = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &device).unwrap();

    let swapped = square.transpose(0, 1).unwrap();
    assert_eq!(swapped.dims(), &[2, 2]);

    // Read back in row-major order: the off-diagonal entries have traded places.
    let flat: Vec<f32> = swapped.flatten_all().unwrap().to_vec1().unwrap();
    assert_eq!(flat, vec![1.0, 3.0, 2.0, 4.0]);
}
75
+
76
#[test]
fn test_tensor_reduction() {
    let device = Device::Cpu;
    let values = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device).unwrap();

    // Whole-tensor reductions produce a scalar tensor that converts to f32.
    let total: f32 = values.sum_all().and_then(|t| t.to_scalar()).unwrap();
    let average: f32 = values.mean_all().and_then(|t| t.to_scalar()).unwrap();

    assert_eq!(total, 10.0);
    assert_eq!(average, 2.5);
}
92
+
93
#[test]
fn test_tensor_indexing() {
    let device = Device::Cpu;
    let tensor = Tensor::new(&[10.0f32, 20.0, 30.0, 40.0], &device).unwrap();

    // `get(i)` on a rank-1 tensor yields the single element at position i.
    for &(index, expected) in &[(0usize, 10.0f32), (2, 30.0)] {
        let element: f32 = tensor.get(index).unwrap().to_scalar().unwrap();
        assert_eq!(element, expected);
    }
}
109
+
110
#[test]
fn test_tensor_matmul() {
    let device = Device::Cpu;
    // 2x3 left operand and 3x2 right operand.
    let lhs = Tensor::new(&[[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]], &device).unwrap();
    let rhs = Tensor::new(&[[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]], &device).unwrap();

    let product = lhs.matmul(&rhs).unwrap();
    assert_eq!(product.dims(), &[2, 2]);

    // Entry (i, j) is the dot product of lhs row i and rhs column j:
    //   [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12] = [58, 64]
    //   [4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12] = [139, 154]
    let rows: Vec<Vec<f32>> = product.to_vec2().unwrap();
    assert_eq!(rows, vec![vec![58.0, 64.0], vec![139.0, 154.0]]);
}
135
+
136
#[test]
fn test_tensor_where() {
    let device = Device::Cpu;

    // Build a mask that is true wherever the input is strictly positive: [true, false, true].
    let raw = Tensor::new(&[1.0f32, 0.0, 1.0], &device).unwrap();
    let zero = Tensor::zeros(raw.shape(), DType::F32, &device).unwrap();
    let mask = raw.gt(&zero).unwrap();

    let if_true = Tensor::new(&[10.0f32, 20.0, 30.0], &device).unwrap();
    let if_false = Tensor::new(&[100.0f32, 200.0, 300.0], &device).unwrap();

    // Each position picks from `if_true` or `if_false` according to the mask.
    let picked: Vec<f32> = mask.where_cond(&if_true, &if_false).unwrap().to_vec1().unwrap();
    assert_eq!(picked, vec![10.0, 200.0, 30.0]);
}
151
+
152
#[test]
fn test_tensor_narrow() {
    let device = Device::Cpu;
    let source = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &device).unwrap();

    // Along dimension 0, keep 3 elements starting at offset 1.
    let (dim, start, len) = (0usize, 1, 3);
    let window: Vec<f32> = source.narrow(dim, start, len).unwrap().to_vec1().unwrap();
    assert_eq!(window, vec![2.0, 3.0, 4.0]);
}
Binary file
Binary file
Binary file
Binary file
Binary file
@@ -0,0 +1,68 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+
5
module Candle
  # Drives a tool-calling conversation loop on top of an LLM.
  #
  # Each turn calls +chat_with_tools(..., execute: true)+. Outputs of any tool calls are
  # appended as "tool" messages and the loop repeats. It stops when the model answers
  # without tool calls, or gives a substantial text answer alongside them. If neither happens
  # within +max_iterations+ calls, AgentMaxIterationsError is raised.
  class Agent
    MAX_ITERATIONS = 10

    # Text accompanying tool calls must be longer than this many characters (after
    # removing <think> blocks and surrounding whitespace) to be accepted as the final answer.
    FINAL_ANSWER_MIN_LENGTH = 50

    # Complete <think>...</think> reasoning blocks, matched across newlines.
    THINK_BLOCK_PATTERN = /<think>.*?<\/think>/m

    private_constant :FINAL_ANSWER_MIN_LENGTH, :THINK_BLOCK_PATTERN

    attr_reader :llm, :tools, :system_prompt, :max_iterations

    # @param llm [#chat_with_tools] model queried on every turn
    # @param tools [Array] tools offered to the model on every turn
    # @param system_prompt [String, nil] optional system message placed first in the conversation
    # @param max_iterations [Integer] maximum number of LLM calls before giving up
    def initialize(llm, tools:, system_prompt: nil, max_iterations: MAX_ITERATIONS)
      @llm = llm
      @tools = tools
      @system_prompt = system_prompt
      @max_iterations = max_iterations
    end

    # Runs the agent loop for one user message.
    #
    # @param user_message [String] the user's request
    # @param options [Hash] extra keyword options forwarded to +chat_with_tools+
    # @return [AgentResult] the final response, the conversation messages, the number of
    #   LLM calls made and the number of "tool" messages recorded
    # @raise [AgentMaxIterationsError] if no final answer arrives within +max_iterations+ calls
    def run(user_message, **options)
      messages = []
      messages << { role: "system", content: @system_prompt } if @system_prompt
      messages << { role: "user", content: user_message }

      iterations = 0
      loop do
        iterations += 1
        if iterations > @max_iterations
          raise AgentMaxIterationsError,
                "Agent exceeded maximum iterations (#{@max_iterations})"
        end

        result = @llm.chat_with_tools(messages, tools: @tools, execute: true, **options)

        # No tool calls: the model has answered.
        unless result.has_tool_calls?
          return build_result(result.text_response || result.raw_response, messages, iterations)
        end

        # A substantial text answer alongside tool calls is treated as final
        # (the model is done; trailing tool calls are noise).
        return build_result(result.text_response, messages, iterations) if final_answer?(result.text_response)

        messages << { role: "assistant", content: result.raw_response }
        result.tool_results.each { |tool_result| messages << tool_message(tool_result) }
      end
    end

    private

    # True when +text+, with <think> blocks removed and whitespace stripped, is longer than
    # FINAL_ANSWER_MIN_LENGTH characters. A nil +text+ never counts as a final answer.
    def final_answer?(text)
      visible = text&.gsub(THINK_BLOCK_PATTERN, "")&.strip
      !visible.nil? && visible.length > FINAL_ANSWER_MIN_LENGTH
    end

    # Formats one executed tool call as a "tool" message. The content is "[name] <JSON result>"
    # or "[name] Error: <error>", with "unknown" used when the call has no name.
    def tool_message(tool_result)
      tool_name = tool_result[:tool_call]&.name || "unknown"
      tool_output = tool_result[:error] ? "Error: #{tool_result[:error]}" : JSON.generate(tool_result[:result])
      { role: "tool", content: "[#{tool_name}] #{tool_output}" }
    end

    # Wraps the final response. tool_calls_made counts the "tool" messages recorded so far.
    def build_result(response, messages, iterations)
      AgentResult.new(
        response: response,
        messages: messages,
        iterations: iterations,
        tool_calls_made: messages.count { |m| m[:role] == "tool" }
      )
    end
  end

  AgentResult = Struct.new(:response, :messages, :iterations, :tool_calls_made, keyword_init: true)
  AgentMaxIterationsError = Class.new(StandardError)
end
@@ -0,0 +1,67 @@
1
module Candle
  # Queries about which acceleration backends the native extension was compiled with.
  # Every value is read from the hash returned by Candle.build_info.
  module BuildInfo
    class << self
      # Logs a warning banner when the build reports CUDA as unavailable (exactly +false+)
      # even though a CUDA toolkit appears to be present. Presence means CUDA_ROOT or CUDA_PATH
      # is set, or /usr/local/cuda or /opt/cuda exists. Nothing is logged otherwise.
      def display_cuda_info
        # CUDA info is now controlled by logger level
        return unless Candle.build_info["cuda_available"] == false

        # :nocov:
        toolkit_hint = ENV['CUDA_ROOT'] || ENV['CUDA_PATH'] ||
                       File.exist?('/usr/local/cuda') || File.exist?('/opt/cuda')
        return unless toolkit_hint

        banner = "=" * 80
        Candle.logger.warn banner
        Candle.logger.warn "Red Candle: CUDA detected on system but not enabled in build."
        Candle.logger.warn "This may be due to CANDLE_DISABLE_CUDA being set during installation."
        Candle.logger.warn "To enable CUDA support, reinstall without CANDLE_DISABLE_CUDA set."
        Candle.logger.warn banner
        # :nocov:
      end

      # Raw "cuda_available" flag from the build info.
      def cuda_available?
        Candle.build_info["cuda_available"]
      end

      # Raw "metal_available" flag from the build info.
      def metal_available?
        Candle.build_info["metal_available"]
      end

      # Raw "mkl_available" flag from the build info.
      def mkl_available?
        Candle.build_info["mkl_available"]
      end

      # Raw "accelerate_available" flag from the build info.
      def accelerate_available?
        Candle.build_info["accelerate_available"]
      end

      # Raw "cudnn_available" flag from the build info.
      def cudnn_available?
        Candle.build_info["cudnn_available"]
      end

      # Returns a symbol-keyed snapshot of the build info. :available_backends lists Metal
      # and/or CUDA (when their flags are truthy, in that order), always followed by "CPU".
      def summary
        info = Candle.build_info

        accelerated = { "Metal" => "metal_available", "CUDA" => "cuda_available" }
                      .select { |_label, flag| info[flag] }
                      .keys

        {
          default_device: info["default_device"],
          available_backends: accelerated + ["CPU"],
          cuda_available: info["cuda_available"],
          metal_available: info["metal_available"],
          mkl_available: info["mkl_available"],
          accelerate_available: info["accelerate_available"],
          cudnn_available: info["cudnn_available"]
        }
      end
    end
  end
end
65
+
66
+ # Display CUDA info on load: runs the CUDA build check once, when this file is required, so a "CUDA detected but not enabled in build" warning (if applicable) is logged at load time.
67
+ Candle::BuildInfo.display_cuda_info
@@ -0,0 +1,10 @@
1
module Candle
  # Legacy device helpers kept for backward compatibility.
  module DeviceUtils
    class << self
      # @deprecated Use {Candle::Device.best} instead
      # Logs a deprecation warning, then returns the best available device
      # (Metal > CUDA > CPU) by delegating to Device.best.
      def best_device
        Candle.logger.warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
        Device.best
      end
    end
  end
end
@@ -0,0 +1,75 @@
1
module Candle
  class EmbeddingModel
    # Default model path for Jina BERT embedding model
    DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"

    # Default tokenizer path used by the deprecated EmbeddingModel.new
    DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"

    # Default embedding model type used by the deprecated EmbeddingModel.new
    DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"

    # Loads a pre-trained embedding model from HuggingFace.
    #
    # When +model_type+ is nil it is guessed from +model_id+ (case-insensitive, first match
    # wins): "jina" -> "jina_bert", "distilbert" -> "distilbert", "minilm" -> "minilm",
    # anything else -> "standard_bert".
    #
    # @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
    # @param device [Candle::Device] The device to use for computation (defaults to best available)
    # @param tokenizer [String, nil] Tokenizer ID; when nil, +model_id+ is used
    # @param model_type [String, nil] The type of embedding model (auto-detected if nil)
    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
    # @return [EmbeddingModel] A new EmbeddingModel instance
    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
      if model_type.nil?
        normalized_id = model_id.downcase
        model_type =
          if normalized_id.match?(/jina/) then "jina_bert"
          elsif normalized_id.match?(/distilbert/) then "distilbert"
          elsif normalized_id.match?(/minilm/) then "minilm"
          else "standard_bert"
          end
      end

      _create(model_id, tokenizer || model_id, device, model_type, embedding_size)
    end

    # Legacy keyword constructor.
    # @deprecated Use {.from_pretrained} instead
    # @param model_path [String] The path to the model on Hugging Face
    # @param tokenizer_path [String] The path to the tokenizer on Hugging Face
    # @param device [Candle::Device] The device to use for computation (defaults to best available)
    # @param model_type [String] The type of embedding model to use
    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
    def self.new(model_path: DEFAULT_MODEL_PATH,
                 tokenizer_path: DEFAULT_TOKENIZER_PATH,
                 device: Candle::Device.best,
                 model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
                 embedding_size: nil)
      notice = "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
      $stderr.puts notice
      _create(model_path, tokenizer_path, device, model_type, embedding_size)
    end

    # Returns the embedding for a string using the specified pooling method.
    # @param str [String] The input text
    # @param pooling_method [String] Pooling method: "pooled", "pooled_normalized", or "cls". Default: "pooled_normalized"
    def embedding(str, pooling_method: "pooled_normalized")
      _embedding(str, pooling_method)
    end

    # Human-readable summary built from +options+. Falls back to "unknown" for missing
    # model/device, and to an empty hash if +options+ raises.
    def inspect
      opts = begin
        options
      rescue StandardError
        {}
      end

      fields = ["model=#{opts["model_id"] || "unknown"}"]
      fields << "type=#{opts["model_type"]}" if opts["model_type"]
      fields << "device=#{opts["device"] || "unknown"}"
      fields << "size=#{opts["embedding_size"]}" if opts["embedding_size"]

      "#<Candle::EmbeddingModel #{fields.join(" ")}>"
    end
  end
end
@@ -0,0 +1,31 @@
1
module Candle
  # String identifiers for the supported embedding model families.
  module EmbeddingModelType
    # Jina Bert embedding models (e.g., jina-embeddings-v2-base-en)
    JINA_BERT = "jina_bert"

    # Standard BERT embedding models (e.g., bert-base-uncased)
    STANDARD_BERT = "standard_bert"

    # MiniLM embedding models (e.g., all-MiniLM-L6-v2)
    MINILM = "minilm"

    # DistilBERT models which can be used for embeddings
    DISTILBERT = "distilbert"

    class << self
      # Every supported model type identifier (a fresh array on each call).
      def all
        [JINA_BERT, STANDARD_BERT, DISTILBERT, MINILM]
      end

      # Maps each model type to an example HuggingFace model path for that family
      # (a fresh hash on each call).
      def suggested_model_paths
        paths = {}
        paths[JINA_BERT] = "jinaai/jina-embeddings-v2-base-en"
        paths[STANDARD_BERT] = "scientistcom/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
        paths[MINILM] = "sentence-transformers/all-MiniLM-L6-v2"
        paths[DISTILBERT] = "scientistcom/distilbert-base-uncased-finetuned-sst-2-english"
        paths
      end
    end
  end
end