red-candle 1.1.1 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ use crate::ruby::{
     utils::{actual_dim, actual_index},
 };
 use crate::ruby::{DType, Device, Result};
-use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
+use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
 
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
@@ -21,30 +21,108 @@ impl std::ops::Deref for Tensor {
     }
 }
 
+// Helper functions for tensor operations
+impl Tensor {
+    /// Check if device is Metal
+    fn is_metal_device(device: &CoreDevice) -> bool {
+        matches!(device.location(), DeviceLocation::Metal { .. })
+    }
+
+    /// Convert tensor to target dtype, handling Metal limitations
+    fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
+        if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
+            // Move to CPU first to avoid Metal conversion limitations
+            self.0
+                .to_device(&CoreDevice::Cpu)
+                .map_err(wrap_candle_err)?
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        } else {
+            // Direct conversion for CPU or when dtype matches
+            self.0
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        }
+    }
+}
+
 impl Tensor {
     pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
         let dtype = dtype
             .map(|dtype| DType::from_rbobject(dtype))
             .unwrap_or(Ok(DType(CoreDType::F32)))?;
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
-        // FIXME: Do not use `to_f64` here.
-        let array = array
-            .into_iter()
-            .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self(
-            CoreTensor::new(array.as_slice(), &device)
-                .map_err(wrap_candle_err)?
-                .to_dtype(dtype.0)
-                .map_err(wrap_candle_err)?,
-        ))
+        let device = device.unwrap_or(Device::best()).as_device()?;
+
+        // Create tensor based on target dtype to avoid conversion issues on Metal
+        let tensor = match dtype.0 {
+            CoreDType::F32 => {
+                // Convert to f32 directly to avoid F64->F32 conversion on Metal
+                let array: Vec<f32> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::F64 => {
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::I64 => {
+                // Convert to i64 directly to avoid conversion issues on Metal
+                let array: Vec<i64> = array
+                    .into_iter()
+                    .map(|v| {
+                        // Try integer first, then float
+                        if let Ok(i) = <i64>::try_convert(v) {
+                            Ok(i)
+                        } else if let Ok(f) = magnus::Float::try_convert(v) {
+                            Ok(f.to_f64() as i64)
+                        } else {
+                            Err(magnus::Error::new(
+                                magnus::exception::type_error(),
+                                "Cannot convert to i64"
+                            ))
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            _ => {
+                // For other dtypes, create on CPU first if on Metal, then convert
+                let cpu_device = CoreDevice::Cpu;
+                let use_cpu = Self::is_metal_device(&device);
+                let target_device = if use_cpu { &cpu_device } else { &device };
+
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let tensor = CoreTensor::new(array.as_slice(), target_device)
+                    .map_err(wrap_candle_err)?
+                    .to_dtype(dtype.0)
+                    .map_err(wrap_candle_err)?;
+
+                // Move to target device if needed
+                if use_cpu {
+                    tensor.to_device(&device).map_err(wrap_candle_err)?
+                } else {
+                    tensor
+                }
+            }
+        };
+
+        Ok(Self(tensor))
     }
 
     pub fn values(&self) -> Result<Vec<f64>> {
-        let values = self
-            .0
-            .to_dtype(CoreDType::F64)
-            .map_err(wrap_candle_err)?
+        let tensor = self.safe_to_dtype(CoreDType::F64)?;
+        let values = tensor
            .flatten_all()
            .map_err(wrap_candle_err)?
            .to_vec1()
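
For Ruby callers, the upshot of the rewritten constructor is that tensors now default to the best available device and are built directly in the requested dtype, so Metal users no longer hit the F64→F32 conversion path. A minimal usage sketch, assuming the dtype symbol and device are passed positionally as in the Rust signature above:

```ruby
require "candle"

# Default device is now Device.best (Metal > CUDA > CPU) rather than the CPU.
t = Candle::Tensor.new([1.0, 2.0, 3.0])   # built directly as F32
i = Candle::Tensor.new([1, 2, 3], :i64)   # integers accepted without a float detour

# values still round-trips through F64; on Metal, safe_to_dtype above
# hops through the CPU when a dtype conversion is required.
p t.values # => [1.0, 2.0, 3.0]
```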
@@ -92,11 +170,8 @@ impl Tensor {
             }
             _ => {
                 // For other dtypes, convert to F64 first
-                let val: f64 = self.0
-                    .to_dtype(CoreDType::F64)
-                    .map_err(wrap_candle_err)?
-                    .to_vec0()
-                    .map_err(wrap_candle_err)?;
+                let tensor = self.safe_to_dtype(CoreDType::F64)?;
+                let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
                 Ok(val)
             }
         }
@@ -541,7 +616,7 @@ impl Tensor {
     /// Creates a new tensor with random values.
     /// &RETURNS&: Tensor
     pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -550,7 +625,7 @@ impl Tensor {
     /// Creates a new tensor with random values from a normal distribution.
     /// &RETURNS&: Tensor
     pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -559,7 +634,7 @@ impl Tensor {
     /// Creates a new tensor filled with ones.
     /// &RETURNS&: Tensor
     pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
@@ -567,7 +642,7 @@ impl Tensor {
     /// Creates a new tensor filled with zeros.
     /// &RETURNS&: Tensor
     pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
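
The same default-device change applies to all four factory methods. A sketch of the Ruby-side effect, assuming the optional device argument is positional as in the Rust signatures:

```ruby
ones  = Candle::Tensor.ones([2, 3])                       # lands on the best device
zeros = Candle::Tensor.zeros([2, 3], Candle::Device.cpu)  # pin to the CPU as in 1.1.1
r     = Candle::Tensor.rand([4])
rn    = Candle::Tensor.randn([4])
```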
@@ -651,4 +726,5 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
     rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
     Ok(())
-}
+}
+
@@ -105,8 +105,8 @@ impl Tokenizer {
         }
 
         let hash = RHash::new();
-        hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
-        hash.aset("tokens", RArray::from_vec(tokens))?;
+        hash.aset(magnus::Symbol::new("ids"), RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+        hash.aset(magnus::Symbol::new("tokens"), RArray::from_vec(tokens))?;
 
         Ok(hash)
     }
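
This is a breaking change for callers that read the returned hash: its keys switch from strings to symbols. An illustrative before/after, assuming the hash comes back from the encode path this method feeds:

```ruby
result = tokenizer.encode("hello world")  # method name assumed from context
result["ids"]     # 1.1.1: string keys
result[:ids]      # 1.2.0: symbol keys
result[:tokens]
```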
@@ -236,9 +236,65 @@ impl Tokenizer {
         Ok(hash)
     }
 
+    /// Get tokenizer options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Get vocab size
+        hash.aset("vocab_size", self.vocab_size(Some(true)))?;
+        hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
+
+        // Get special tokens info
+        let special_tokens = self.get_special_tokens()?;
+        hash.aset("special_tokens", special_tokens)?;
+
+        // Get padding/truncation info if available
+        let inner_tokenizer = self.0.inner();
+
+        // Check if padding is enabled
+        if let Some(_padding) = inner_tokenizer.get_padding() {
+            let padding_info = RHash::new();
+            padding_info.aset("enabled", true)?;
+            // Note: We can't easily extract all padding params from the tokenizers library,
+            // but we can indicate it's enabled
+            hash.aset("padding", padding_info)?;
+        }
+
+        // Check if truncation is enabled
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            let truncation_info = RHash::new();
+            truncation_info.aset("enabled", true)?;
+            truncation_info.aset("max_length", truncation.max_length)?;
+            hash.aset("truncation", truncation_info)?;
+        }
+
+        Ok(hash)
+    }
+
     /// String representation
     pub fn inspect(&self) -> String {
-        format!("#<Candle::Tokenizer vocab_size={}>", self.vocab_size(Some(true)))
+        let vocab_size = self.vocab_size(Some(true));
+        let special_tokens = self.get_special_tokens()
+            .ok()
+            .map(|h| h.len())
+            .unwrap_or(0);
+
+        let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
+
+        if special_tokens > 0 {
+            parts.push(format!("special_tokens={}", special_tokens));
+        }
+
+        // Check for padding/truncation
+        let inner_tokenizer = self.0.inner();
+        if inner_tokenizer.get_padding().is_some() {
+            parts.push("padding=enabled".to_string());
+        }
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            parts.push(format!("truncation={}", truncation.max_length));
+        }
+
+        parts.join(" ") + ">"
     }
 }
 
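The new options method exposes the hash assembled above (string keys, per the aset calls), and inspect folds the same data into a one-line summary. A sketch with illustrative values, given some tokenizer instance tok:

```ruby
tok.options
# => {"vocab_size"=>30522, "vocab_size_base"=>30522,
#     "special_tokens"=>{...},
#     "truncation"=>{"enabled"=>true, "max_length"=>512}}

tok.inspect   # to_s is aliased to the same method below
# => "#<Candle::Tokenizer vocab_size=30522 special_tokens=5 truncation=512>"
```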
@@ -262,6 +318,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
     tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
     tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
+    tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
     tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
     tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
 
@@ -100,4 +100,5 @@ impl TokenizerWrapper {
     pub fn inner_mut(&mut self) -> &mut Tokenizer {
         &mut self.tokenizer
     }
-}
+}
+
@@ -0,0 +1,43 @@
+use candle_core::Device as CoreDevice;
+
+#[test]
+fn test_device_creation() {
+    // CPU device should always work
+    let cpu = CoreDevice::Cpu;
+    assert!(matches!(cpu, CoreDevice::Cpu));
+
+    // Test device display
+    assert_eq!(format!("{:?}", cpu), "Cpu");
+}
+
+#[cfg(feature = "cuda")]
+#[test]
+#[ignore = "requires CUDA hardware"]
+fn test_cuda_device_creation() {
+    // This might fail if no CUDA device is available
+    match CoreDevice::new_cuda(0) {
+        Ok(device) => assert!(matches!(device, CoreDevice::Cuda(_))),
+        Err(_) => println!("No CUDA device available for testing"),
+    }
+}
+
+#[cfg(feature = "metal")]
+#[test]
+#[ignore = "requires Metal hardware"]
+fn test_metal_device_creation() {
+    // This might fail if no Metal device is available
+    match CoreDevice::new_metal(0) {
+        Ok(device) => assert!(matches!(device, CoreDevice::Metal(_))),
+        Err(_) => println!("No Metal device available for testing"),
+    }
+}
+
+#[test]
+fn test_device_matching() {
+    let cpu1 = CoreDevice::Cpu;
+    let cpu2 = CoreDevice::Cpu;
+
+    // Same device types should match
+    assert!(matches!(cpu1, CoreDevice::Cpu));
+    assert!(matches!(cpu2, CoreDevice::Cpu));
+}
@@ -0,0 +1,162 @@
+use candle_core::{Tensor, Device, DType};
+
+#[test]
+fn test_tensor_creation() {
+    let device = Device::Cpu;
+
+    // Test tensor creation from slice
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let tensor = Tensor::new(&data[..], &device).unwrap();
+    assert_eq!(tensor.dims(), &[4]);
+    assert_eq!(tensor.dtype(), DType::F32);
+
+    // Test zeros
+    let zeros = Tensor::zeros(&[2, 3], DType::F32, &device).unwrap();
+    assert_eq!(zeros.dims(), &[2, 3]);
+
+    // Test ones
+    let ones = Tensor::ones(&[3, 2], DType::F32, &device).unwrap();
+    assert_eq!(ones.dims(), &[3, 2]);
+}
+
+#[test]
+fn test_tensor_arithmetic() {
+    let device = Device::Cpu;
+
+    let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
+    let b = Tensor::new(&[4.0f32, 5.0, 6.0], &device).unwrap();
+
+    // Addition
+    let sum = a.add(&b).unwrap();
+    let sum_vec: Vec<f32> = sum.to_vec1().unwrap();
+    assert_eq!(sum_vec, vec![5.0, 7.0, 9.0]);
+
+    // Subtraction
+    let diff = a.sub(&b).unwrap();
+    let diff_vec: Vec<f32> = diff.to_vec1().unwrap();
+    assert_eq!(diff_vec, vec![-3.0, -3.0, -3.0]);
+
+    // Multiplication
+    let prod = a.mul(&b).unwrap();
+    let prod_vec: Vec<f32> = prod.to_vec1().unwrap();
+    assert_eq!(prod_vec, vec![4.0, 10.0, 18.0]);
+}
+
+#[test]
+fn test_tensor_reshape() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device).unwrap();
+
+    // Reshape to 2x3
+    let reshaped = tensor.reshape(&[2, 3]).unwrap();
+    assert_eq!(reshaped.dims(), &[2, 3]);
+
+    // Reshape to 3x2
+    let reshaped = tensor.reshape(&[3, 2]).unwrap();
+    assert_eq!(reshaped.dims(), &[3, 2]);
+}
+
+#[test]
+fn test_tensor_transpose() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)
+        .unwrap()
+        .reshape(&[2, 2])
+        .unwrap();
+
+    let transposed = tensor.transpose(0, 1).unwrap();
+    assert_eq!(transposed.dims(), &[2, 2]);
+
+    let values: Vec<f32> = transposed.flatten_all().unwrap().to_vec1().unwrap();
+    assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
+}
+
+#[test]
+fn test_tensor_reduction() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device).unwrap();
+
+    // Sum
+    let sum = tensor.sum_all().unwrap();
+    let sum_val: f32 = sum.to_scalar().unwrap();
+    assert_eq!(sum_val, 10.0);
+
+    // Mean
+    let mean = tensor.mean_all().unwrap();
+    let mean_val: f32 = mean.to_scalar().unwrap();
+    assert_eq!(mean_val, 2.5);
+}
+
+#[test]
+fn test_tensor_indexing() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[10.0f32, 20.0, 30.0, 40.0], &device).unwrap();
+
+    // Get element at index 0
+    let elem = tensor.get(0).unwrap();
+    let val: f32 = elem.to_scalar().unwrap();
+    assert_eq!(val, 10.0);
+
+    // Get element at index 2
+    let elem = tensor.get(2).unwrap();
+    let val: f32 = elem.to_scalar().unwrap();
+    assert_eq!(val, 30.0);
+}
+
+#[test]
+fn test_tensor_matmul() {
+    let device = Device::Cpu;
+
+    // 2x3 matrix
+    let a = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device)
+        .unwrap()
+        .reshape(&[2, 3])
+        .unwrap();
+
+    // 3x2 matrix
+    let b = Tensor::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &device)
+        .unwrap()
+        .reshape(&[3, 2])
+        .unwrap();
+
+    // Matrix multiplication
+    let result = a.matmul(&b).unwrap();
+    assert_eq!(result.dims(), &[2, 2]);
+
+    let values: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
+    // [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12, 4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12]
+    // = [58, 64, 139, 154]
+    assert_eq!(values, vec![58.0, 64.0, 139.0, 154.0]);
+}
+
+#[test]
+fn test_tensor_where() {
+    let device = Device::Cpu;
+
+    // Create a condition tensor where values > 0 are treated as true
+    let cond_values = Tensor::new(&[1.0f32, 0.0, 1.0], &device).unwrap();
+    let cond = cond_values.gt(&Tensor::zeros(cond_values.shape(), DType::F32, &device).unwrap()).unwrap();
+
+    let on_true = Tensor::new(&[10.0f32, 20.0, 30.0], &device).unwrap();
+    let on_false = Tensor::new(&[100.0f32, 200.0, 300.0], &device).unwrap();
+
+    let result = cond.where_cond(&on_true, &on_false).unwrap();
+    let values: Vec<f32> = result.to_vec1().unwrap();
+    assert_eq!(values, vec![10.0, 200.0, 30.0]);
+}
+
+#[test]
+fn test_tensor_narrow() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &device).unwrap();
+
+    // Narrow from index 1, length 3
+    let narrowed = tensor.narrow(0, 1, 3).unwrap();
+    let values: Vec<f32> = narrowed.to_vec1().unwrap();
+    assert_eq!(values, vec![2.0, 3.0, 4.0]);
+}
+}
data/lib/candle/device_utils.rb CHANGED
@@ -1,22 +1,10 @@
 module Candle
   module DeviceUtils
+    # @deprecated Use {Candle::Device.best} instead
     # Get the best available device (Metal > CUDA > CPU)
     def self.best_device
-      # Try devices in order of preference
-      begin
-        # Try Metal first (for Mac users)
-        Device.metal
-      rescue
-        # :nocov:
-        begin
-          # Try CUDA next (for NVIDIA GPU users)
-          Device.cuda
-        rescue
-          # Fall back to CPU
-          Device.cpu
-        end
-        # :nocov:
-      end
+      warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
+      Device.best
     end
   end
 end
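
Migration sketch for the deprecation above:

```ruby
device = Candle::DeviceUtils.best_device  # still works, now emits the deprecation warning
device = Candle::Device.best              # preferred from 1.2.0 on
```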
data/lib/candle/embedding_model.rb CHANGED
@@ -9,7 +9,36 @@ module Candle
     # Default embedding model type
     DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
 
+    # Load a pre-trained embedding model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @param tokenizer [String, nil] The tokenizer to use (defaults to using the model's tokenizer)
+    # @param model_type [String, nil] The type of embedding model (auto-detected if nil)
+    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
+    # @return [EmbeddingModel] A new EmbeddingModel instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
+      # Auto-detect model type based on model_id if not provided
+      if model_type.nil?
+        model_type = case model_id.downcase
+                     when /jina/
+                       "jina_bert"
+                     when /distilbert/
+                       "distilbert"
+                     when /minilm/
+                       "minilm"
+                     else
+                       "standard_bert"
+                     end
+      end
+
+      # Use model_id as tokenizer if not specified (usually what you want)
+      tokenizer_id = tokenizer || model_id
+
+      _create(model_id, tokenizer_id, device, model_type, embedding_size)
+    end
+
     # Constructor for creating a new EmbeddingModel with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
@@ -17,9 +46,10 @@ module Candle
     # @param embedding_size [Integer, nil] Override for the embedding size (optional)
     def self.new(model_path: DEFAULT_MODEL_PATH,
                  tokenizer_path: DEFAULT_TOKENIZER_PATH,
-                 device: Candle::Device.cpu,
+                 device: Candle::Device.best,
                  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
                  embedding_size: nil)
+      $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
       _create(model_path, tokenizer_path, device, model_type, embedding_size)
     end
     # Returns the embedding for a string using the specified pooling method.
@@ -28,5 +58,18 @@ module Candle
     def embedding(str, pooling_method: "pooled_normalized")
       _embedding(str, pooling_method)
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::EmbeddingModel"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "type=#{opts["model_type"]}" if opts["model_type"]
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
+
+      parts.join(" ") + ">"
+    end
   end
 end
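
Taken together, the new loader and inspect give this kind of workflow. The model id below is illustrative, chosen to exercise the /minilm/ branch of the auto-detection:

```ruby
model = Candle::EmbeddingModel.from_pretrained(
  "sentence-transformers/all-MiniLM-L6-v2"  # auto-detected as "minilm"
)
vec = model.embedding("hello world")        # default pooling: "pooled_normalized"

# The old constructor still works but now warns on $stderr:
Candle::EmbeddingModel.new(model_path: "jinaai/jina-embeddings-v2-base-en")
```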
data/lib/candle/llm.rb CHANGED
@@ -189,6 +189,45 @@ module Candle
       prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      # Extract key information
+      model_type = opts["model_type"] || "Unknown"
+      device = opts["device"] || self.device.to_s rescue "unknown"
+
+      # Build the inspect string
+      parts = ["#<Candle::LLM"]
+
+      # Add base model or model_id
+      if opts["base_model"]
+        parts << "model=#{opts["base_model"]}"
+      elsif opts["model_id"]
+        parts << "model=#{opts["model_id"]}"
+      elsif respond_to?(:model_id)
+        parts << "model=#{model_id}"
+      end
+
+      # Add GGUF file if present
+      if opts["gguf_file"]
+        parts << "gguf=#{opts["gguf_file"]}"
+      end
+
+      # Add device
+      parts << "device=#{device}"
+
+      # Add model type
+      parts << "type=#{model_type}"
+
+      # Add architecture for GGUF models
+      if opts["architecture"]
+        parts << "arch=#{opts["architecture"]}"
+      end
+
+      parts.join(" ") + ">"
+    end
 
     def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
       begin
@@ -206,7 +245,7 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+    def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
       model_str = if gguf_file
                     "#{model_id}@#{gguf_file}"
                   else
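
LLM loading picks up the same default device change. A sketch; the model id and GGUF file name here are hypothetical:

```ruby
llm = Candle::LLM.from_pretrained(
  "example-org/example-model",      # hypothetical model id
  gguf_file: "model-q4_k_m.gguf"    # hypothetical quantized weights file
)

# Pin to the CPU to keep the 1.1.1 behavior:
cpu_llm = Candle::LLM.from_pretrained("example-org/example-model", device: Candle::Device.cpu)
```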
@@ -393,5 +432,28 @@ module Candle
       }
       new(defaults.merge(opts))
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::GenerationConfig"]
+
+      # Add key configuration parameters
+      parts << "temp=#{opts["temperature"]}" if opts["temperature"]
+      parts << "max=#{opts["max_length"]}" if opts["max_length"]
+      parts << "top_p=#{opts["top_p"]}" if opts["top_p"]
+      parts << "top_k=#{opts["top_k"]}" if opts["top_k"]
+      parts << "seed=#{opts["seed"]}" if opts["seed"]
+
+      # Add flags
+      flags = []
+      flags << "debug" if opts["debug_tokens"]
+      flags << "constraint" if opts["has_constraint"]
+      flags << "stop_on_match" if opts["stop_on_match"]
+      parts << "flags=[#{flags.join(",")}]" if flags.any?
+
+      parts.join(" ") + ">"
+    end
   end
 end
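
Illustrative output of the new inspect methods; which fields appear depends on what options returns for the instance, and the values below are hypothetical:

```ruby
llm.inspect
# => "#<Candle::LLM model=example-org/example-model gguf=model-q4_k_m.gguf device=metal type=llama>"

Candle::GenerationConfig.balanced.inspect
# => "#<Candle::GenerationConfig temp=0.7 max=512 top_p=0.9>"
```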