red-candle 1.1.1 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +102 -45
- data/Rakefile +108 -77
- data/ext/candle/src/lib.rs +2 -4
- data/ext/candle/src/llm/quantized_gguf.rs +18 -2
- data/ext/candle/src/ruby/device.rs +32 -1
- data/ext/candle/src/ruby/dtype.rs +1 -0
- data/ext/candle/src/ruby/embedding_model.rs +74 -28
- data/ext/candle/src/ruby/errors.rs +1 -0
- data/ext/candle/src/ruby/llm.rs +96 -1
- data/ext/candle/src/ruby/mod.rs +2 -0
- data/ext/candle/src/{ner.rs → ruby/ner.rs} +47 -15
- data/ext/candle/src/{reranker.rs → ruby/reranker.rs} +24 -2
- data/ext/candle/src/ruby/tensor.rs +103 -27
- data/ext/candle/src/ruby/tokenizer.rs +60 -3
- data/ext/candle/src/tokenizer/mod.rs +2 -1
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/device_utils.rb +3 -15
- data/lib/candle/embedding_model.rb +44 -1
- data/lib/candle/llm.rb +63 -1
- data/lib/candle/ner.rb +34 -22
- data/lib/candle/reranker.rb +20 -1
- data/lib/candle/tensor.rb +15 -0
- data/lib/candle/version.rb +1 -1
- metadata +20 -4
data/ext/candle/src/ruby/tensor.rs
CHANGED
@@ -6,7 +6,7 @@ use crate::ruby::{
     utils::{actual_dim, actual_index},
 };
 use crate::ruby::{DType, Device, Result};
-use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
+use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
 
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
@@ -21,30 +21,108 @@ impl std::ops::Deref for Tensor {
     }
 }
 
+// Helper functions for tensor operations
+impl Tensor {
+    /// Check if device is Metal
+    fn is_metal_device(device: &CoreDevice) -> bool {
+        matches!(device.location(), DeviceLocation::Metal { .. })
+    }
+
+    /// Convert tensor to target dtype, handling Metal limitations
+    fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
+        if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
+            // Move to CPU first to avoid Metal conversion limitations
+            self.0
+                .to_device(&CoreDevice::Cpu)
+                .map_err(wrap_candle_err)?
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        } else {
+            // Direct conversion for CPU or when dtype matches
+            self.0
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        }
+    }
+}
+
 impl Tensor {
     pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
         let dtype = dtype
             .map(|dtype| DType::from_rbobject(dtype))
             .unwrap_or(Ok(DType(CoreDType::F32)))?;
-        let device = device.unwrap_or(Device::
-            .
+        let device = device.unwrap_or(Device::best()).as_device()?;
+
+        // Create tensor based on target dtype to avoid conversion issues on Metal
+        let tensor = match dtype.0 {
+            CoreDType::F32 => {
+                // Convert to f32 directly to avoid F64->F32 conversion on Metal
+                let array: Vec<f32> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::F64 => {
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::I64 => {
+                // Convert to i64 directly to avoid conversion issues on Metal
+                let array: Vec<i64> = array
+                    .into_iter()
+                    .map(|v| {
+                        // Try integer first, then float
+                        if let Ok(i) = <i64>::try_convert(v) {
+                            Ok(i)
+                        } else if let Ok(f) = magnus::Float::try_convert(v) {
+                            Ok(f.to_f64() as i64)
+                        } else {
+                            Err(magnus::Error::new(
+                                magnus::exception::type_error(),
+                                "Cannot convert to i64"
+                            ))
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            _ => {
+                // For other dtypes, create on CPU first if on Metal, then convert
+                let cpu_device = CoreDevice::Cpu;
+                let use_cpu = Self::is_metal_device(&device);
+                let target_device = if use_cpu { &cpu_device } else { &device };
+
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let tensor = CoreTensor::new(array.as_slice(), target_device)
+                    .map_err(wrap_candle_err)?
+                    .to_dtype(dtype.0)
+                    .map_err(wrap_candle_err)?;
+
+                // Move to target device if needed
+                if use_cpu {
+                    tensor.to_device(&device).map_err(wrap_candle_err)?
+                } else {
+                    tensor
+                }
+            }
+        };
+
+        Ok(Self(tensor))
     }
 
     pub fn values(&self) -> Result<Vec<f64>> {
-        let
-            .to_dtype(CoreDType::F64)
-            .map_err(wrap_candle_err)?
+        let tensor = self.safe_to_dtype(CoreDType::F64)?;
+        let values = tensor
             .flatten_all()
             .map_err(wrap_candle_err)?
             .to_vec1()
@@ -92,11 +170,8 @@ impl Tensor {
             }
             _ => {
                 // For other dtypes, convert to F64 first
-                let
-                    .map_err(wrap_candle_err)?
-                    .to_vec0()
-                    .map_err(wrap_candle_err)?;
+                let tensor = self.safe_to_dtype(CoreDType::F64)?;
+                let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
                 Ok(val)
             }
         }
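Taken together, these tensor.rs changes work around Metal's missing dtype-conversion kernels: `Tensor.new` now builds the data in the requested dtype up front, and reads like `values` go through `safe_to_dtype`, which detours via the CPU when a conversion is needed on Metal. A minimal Ruby sketch of the user-visible effect; the positional argument order is assumed from the Rust signature above, and a Metal-capable machine is assumed:

    device = Candle::Device.metal

    # I64 data is built as i64 from the start, so no F64 -> I64
    # conversion ever has to run on the Metal backend.
    t = Candle::Tensor.new([1, 2, 3], :i64, device)

    # values() converts to F64 via safe_to_dtype, hopping to the CPU
    # for the conversion Metal cannot do, then returns Ruby floats.
    t.values # => [1.0, 2.0, 3.0]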
@@ -541,7 +616,7 @@ impl Tensor {
     /// Creates a new tensor with random values.
     /// &RETURNS&: Tensor
     pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -550,7 +625,7 @@ impl Tensor {
     /// Creates a new tensor with random values from a normal distribution.
     /// &RETURNS&: Tensor
     pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -559,7 +634,7 @@ impl Tensor {
     /// Creates a new tensor filled with ones.
     /// &RETURNS&: Tensor
     pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
@@ -567,7 +642,7 @@ impl Tensor {
     /// Creates a new tensor filled with zeros.
     /// &RETURNS&: Tensor
     pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
@@ -651,4 +726,5 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
     rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
     Ok(())
-}
+}
+
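With `Device::best()` as the fallback in `new`, `rand`, `randn`, `ones`, and `zeros`, tensors now land on the best available backend (Metal > CUDA > CPU, per the device-selection change described further below) unless a device is passed explicitly. A sketch, again assuming the positional device argument from the Rust signatures:

    # Allocated on the best available device:
    ones  = Candle::Tensor.ones([2, 3])
    noise = Candle::Tensor.randn([2, 3])

    # An explicit device still takes precedence:
    cpu_zeros = Candle::Tensor.zeros([2, 3], Candle::Device.cpu)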
data/ext/candle/src/ruby/tokenizer.rs
CHANGED
@@ -105,8 +105,8 @@ impl Tokenizer {
         }
 
         let hash = RHash::new();
-        hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
-        hash.aset("tokens", RArray::from_vec(tokens))?;
+        hash.aset(magnus::Symbol::new("ids"), RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+        hash.aset(magnus::Symbol::new("tokens"), RArray::from_vec(tokens))?;
 
         Ok(hash)
     }
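Switching these keys from strings to `magnus::Symbol` is a small breaking change for Ruby callers: the returned hash must now be indexed with `:ids` and `:tokens`. A migration sketch; `result` stands for the hash returned by whichever tokenizer method this code backs, since the method name is not visible in the hunk:

    result["ids"]   # 1.1.1-style string key; returns nil in 1.2.0
    result[:ids]    # 1.2.0: array of token ids
    result[:tokens] # 1.2.0: array of token strings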
@@ -236,9 +236,65 @@ impl Tokenizer {
         Ok(hash)
     }
 
+    /// Get tokenizer options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Get vocab size
+        hash.aset("vocab_size", self.vocab_size(Some(true)))?;
+        hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
+
+        // Get special tokens info
+        let special_tokens = self.get_special_tokens()?;
+        hash.aset("special_tokens", special_tokens)?;
+
+        // Get padding/truncation info if available
+        let inner_tokenizer = self.0.inner();
+
+        // Check if padding is enabled
+        if let Some(_padding) = inner_tokenizer.get_padding() {
+            let padding_info = RHash::new();
+            padding_info.aset("enabled", true)?;
+            // Note: We can't easily extract all padding params from the tokenizers library,
+            // but we can indicate it's enabled
+            hash.aset("padding", padding_info)?;
+        }
+
+        // Check if truncation is enabled
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            let truncation_info = RHash::new();
+            truncation_info.aset("enabled", true)?;
+            truncation_info.aset("max_length", truncation.max_length)?;
+            hash.aset("truncation", truncation_info)?;
+        }
+
+        Ok(hash)
+    }
+
     /// String representation
     pub fn inspect(&self) -> String {
-
+        let vocab_size = self.vocab_size(Some(true));
+        let special_tokens = self.get_special_tokens()
+            .ok()
+            .map(|h| h.len())
+            .unwrap_or(0);
+
+        let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
+
+        if special_tokens > 0 {
+            parts.push(format!("special_tokens={}", special_tokens));
+        }
+
+        // Check for padding/truncation
+        let inner_tokenizer = self.0.inner();
+        if inner_tokenizer.get_padding().is_some() {
+            parts.push("padding=enabled".to_string());
+        }
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            parts.push(format!("truncation={}", truncation.max_length));
+        }
+
+        parts.join(" ") + ">"
     }
 }
 
@@ -262,6 +318,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
     tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
     tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
+    tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
     tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
     tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
 
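The new `options` method, registered above, exposes the tokenizer configuration as a hash, and `inspect` now summarizes the same data. A sketch of what this might look like from Ruby; the hash shape follows the keys set in the Rust code, while the concrete values and the loading call are illustrative:

    tokenizer = Candle::Tokenizer.from_pretrained("bert-base-uncased")

    tokenizer.options
    # => {"vocab_size"=>30522, "vocab_size_base"=>30522,
    #     "special_tokens"=>{...},
    #     "truncation"=>{"enabled"=>true, "max_length"=>512}}
    # "padding"/"truncation" keys appear only when those features are enabled.

    tokenizer.inspect
    # => "#<Candle::Tokenizer vocab_size=30522 special_tokens=5 truncation=512>"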
data/ext/candle/tests/device_tests.rs
ADDED
@@ -0,0 +1,43 @@
+use candle_core::Device as CoreDevice;
+
+#[test]
+fn test_device_creation() {
+    // CPU device should always work
+    let cpu = CoreDevice::Cpu;
+    assert!(matches!(cpu, CoreDevice::Cpu));
+
+    // Test device display
+    assert_eq!(format!("{:?}", cpu), "Cpu");
+}
+
+#[cfg(feature = "cuda")]
+#[test]
+#[ignore = "requires CUDA hardware"]
+fn test_cuda_device_creation() {
+    // This might fail if no CUDA device is available
+    match CoreDevice::new_cuda(0) {
+        Ok(device) => assert!(matches!(device, CoreDevice::Cuda(_))),
+        Err(_) => println!("No CUDA device available for testing"),
+    }
+}
+
+#[cfg(feature = "metal")]
+#[test]
+#[ignore = "requires Metal hardware"]
+fn test_metal_device_creation() {
+    // This might fail if no Metal device is available
+    match CoreDevice::new_metal(0) {
+        Ok(device) => assert!(matches!(device, CoreDevice::Metal(_))),
+        Err(_) => println!("No Metal device available for testing"),
+    }
+}
+
+#[test]
+fn test_device_matching() {
+    let cpu1 = CoreDevice::Cpu;
+    let cpu2 = CoreDevice::Cpu;
+
+    // Same device types should match
+    assert!(matches!(cpu1, CoreDevice::Cpu));
+    assert!(matches!(cpu2, CoreDevice::Cpu));
+}
data/ext/candle/tests/tensor_tests.rs
ADDED
@@ -0,0 +1,162 @@
+use candle_core::{Tensor, Device, DType};
+
+#[test]
+fn test_tensor_creation() {
+    let device = Device::Cpu;
+
+    // Test tensor creation from slice
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let tensor = Tensor::new(&data[..], &device).unwrap();
+    assert_eq!(tensor.dims(), &[4]);
+    assert_eq!(tensor.dtype(), DType::F32);
+
+    // Test zeros
+    let zeros = Tensor::zeros(&[2, 3], DType::F32, &device).unwrap();
+    assert_eq!(zeros.dims(), &[2, 3]);
+
+    // Test ones
+    let ones = Tensor::ones(&[3, 2], DType::F32, &device).unwrap();
+    assert_eq!(ones.dims(), &[3, 2]);
+}
+
+#[test]
+fn test_tensor_arithmetic() {
+    let device = Device::Cpu;
+
+    let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
+    let b = Tensor::new(&[4.0f32, 5.0, 6.0], &device).unwrap();
+
+    // Addition
+    let sum = a.add(&b).unwrap();
+    let sum_vec: Vec<f32> = sum.to_vec1().unwrap();
+    assert_eq!(sum_vec, vec![5.0, 7.0, 9.0]);
+
+    // Subtraction
+    let diff = a.sub(&b).unwrap();
+    let diff_vec: Vec<f32> = diff.to_vec1().unwrap();
+    assert_eq!(diff_vec, vec![-3.0, -3.0, -3.0]);
+
+    // Multiplication
+    let prod = a.mul(&b).unwrap();
+    let prod_vec: Vec<f32> = prod.to_vec1().unwrap();
+    assert_eq!(prod_vec, vec![4.0, 10.0, 18.0]);
+}
+
+#[test]
+fn test_tensor_reshape() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device).unwrap();
+
+    // Reshape to 2x3
+    let reshaped = tensor.reshape(&[2, 3]).unwrap();
+    assert_eq!(reshaped.dims(), &[2, 3]);
+
+    // Reshape to 3x2
+    let reshaped = tensor.reshape(&[3, 2]).unwrap();
+    assert_eq!(reshaped.dims(), &[3, 2]);
+}
+
+#[test]
+fn test_tensor_transpose() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)
+        .unwrap()
+        .reshape(&[2, 2])
+        .unwrap();
+
+    let transposed = tensor.transpose(0, 1).unwrap();
+    assert_eq!(transposed.dims(), &[2, 2]);
+
+    let values: Vec<f32> = transposed.flatten_all().unwrap().to_vec1().unwrap();
+    assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
+}
+
+#[test]
+fn test_tensor_reduction() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device).unwrap();
+
+    // Sum
+    let sum = tensor.sum_all().unwrap();
+    let sum_val: f32 = sum.to_scalar().unwrap();
+    assert_eq!(sum_val, 10.0);
+
+    // Mean
+    let mean = tensor.mean_all().unwrap();
+    let mean_val: f32 = mean.to_scalar().unwrap();
+    assert_eq!(mean_val, 2.5);
+}
+
+#[test]
+fn test_tensor_indexing() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[10.0f32, 20.0, 30.0, 40.0], &device).unwrap();
+
+    // Get element at index 0
+    let elem = tensor.get(0).unwrap();
+    let val: f32 = elem.to_scalar().unwrap();
+    assert_eq!(val, 10.0);
+
+    // Get element at index 2
+    let elem = tensor.get(2).unwrap();
+    let val: f32 = elem.to_scalar().unwrap();
+    assert_eq!(val, 30.0);
+}
+
+#[test]
+fn test_tensor_matmul() {
+    let device = Device::Cpu;
+
+    // 2x3 matrix
+    let a = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device)
+        .unwrap()
+        .reshape(&[2, 3])
+        .unwrap();
+
+    // 3x2 matrix
+    let b = Tensor::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &device)
+        .unwrap()
+        .reshape(&[3, 2])
+        .unwrap();
+
+    // Matrix multiplication
+    let result = a.matmul(&b).unwrap();
+    assert_eq!(result.dims(), &[2, 2]);
+
+    let values: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
+    // [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12, 4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12]
+    // = [58, 64, 139, 154]
+    assert_eq!(values, vec![58.0, 64.0, 139.0, 154.0]);
+}
+
+#[test]
+fn test_tensor_where() {
+    let device = Device::Cpu;
+
+    // Create a condition tensor where values > 0 are treated as true
+    let cond_values = Tensor::new(&[1.0f32, 0.0, 1.0], &device).unwrap();
+    let cond = cond_values.gt(&Tensor::zeros(cond_values.shape(), DType::F32, &device).unwrap()).unwrap();
+
+    let on_true = Tensor::new(&[10.0f32, 20.0, 30.0], &device).unwrap();
+    let on_false = Tensor::new(&[100.0f32, 200.0, 300.0], &device).unwrap();
+
+    let result = cond.where_cond(&on_true, &on_false).unwrap();
+    let values: Vec<f32> = result.to_vec1().unwrap();
+    assert_eq!(values, vec![10.0, 200.0, 30.0]);
+}
+
+#[test]
+fn test_tensor_narrow() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &device).unwrap();
+
+    // Narrow from index 1, length 3
+    let narrowed = tensor.narrow(0, 1, 3).unwrap();
+    let values: Vec<f32> = narrowed.to_vec1().unwrap();
+    assert_eq!(values, vec![2.0, 3.0, 4.0]);
+}
data/lib/candle/device_utils.rb
CHANGED
@@ -1,22 +1,10 @@
 module Candle
   module DeviceUtils
+    # @deprecated Use {Candle::Device.best} instead
     # Get the best available device (Metal > CUDA > CPU)
     def self.best_device
-      # Try Metal first (for Mac users)
-      Device.metal
-    rescue
-      # :nocov:
-      begin
-        # Try CUDA next (for NVIDIA GPU users)
-        Device.cuda
-      rescue
-        # Fall back to CPU
-        Device.cpu
-      end
-      # :nocov:
-      end
+      warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
+      Device.best
     end
   end
 end
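The fallback chain (Metal, then CUDA, then CPU) now lives behind `Device.best` in the Rust extension (see the `device.rs` entry in the file list, +32 -1), and `DeviceUtils.best_device` survives only as a warning shim. Migration is a one-line change:

    # 1.1.1 (still works in 1.2.0, but prints a deprecation warning):
    device = Candle::DeviceUtils.best_device

    # 1.2.0:
    device = Candle::Device.best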
data/lib/candle/embedding_model.rb
CHANGED
@@ -9,7 +9,36 @@ module Candle
     # Default embedding model type
     DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
 
+    # Load a pre-trained embedding model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @param tokenizer [String, nil] The tokenizer to use (defaults to using the model's tokenizer)
+    # @param model_type [String, nil] The type of embedding model (auto-detected if nil)
+    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
+    # @return [EmbeddingModel] A new EmbeddingModel instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
+      # Auto-detect model type based on model_id if not provided
+      if model_type.nil?
+        model_type = case model_id.downcase
+                     when /jina/
+                       "jina_bert"
+                     when /distilbert/
+                       "distilbert"
+                     when /minilm/
+                       "minilm"
+                     else
+                       "standard_bert"
+                     end
+      end
+
+      # Use model_id as tokenizer if not specified (usually what you want)
+      tokenizer_id = tokenizer || model_id
+
+      _create(model_id, tokenizer_id, device, model_type, embedding_size)
+    end
+
     # Constructor for creating a new EmbeddingModel with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
@@ -17,9 +46,10 @@ module Candle
     # @param embedding_size [Integer, nil] Override for the embedding size (optional)
     def self.new(model_path: DEFAULT_MODEL_PATH,
                  tokenizer_path: DEFAULT_TOKENIZER_PATH,
-                 device: Candle::Device.
+                 device: Candle::Device.best,
                  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
                  embedding_size: nil)
+      $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
       _create(model_path, tokenizer_path, device, model_type, embedding_size)
     end
     # Returns the embedding for a string using the specified pooling method.
@@ -28,5 +58,18 @@ module Candle
     def embedding(str, pooling_method: "pooled_normalized")
       _embedding(str, pooling_method)
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::EmbeddingModel"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "type=#{opts["model_type"]}" if opts["model_type"]
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
+
+      parts.join(" ") + ">"
+    end
   end
 end
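`EmbeddingModel` thus gains a `from_pretrained` constructor mirroring `LLM.from_pretrained`, deprecates keyword-style `new`, and gets a readable `inspect`. A usage sketch built from the signatures above; the model id is the default named in the docstring:

    # New API: model type auto-detected, device defaults to Device.best.
    model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en")
    model.embedding("hello world")
    model.inspect # => "#<Candle::EmbeddingModel model=... type=jina_bert device=...>"

    # Old API: still functional, but warns on stderr.
    model = Candle::EmbeddingModel.new(model_path: "jinaai/jina-embeddings-v2-base-en")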
data/lib/candle/llm.rb
CHANGED
@@ -189,6 +189,45 @@ module Candle
       prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      # Extract key information
+      model_type = opts["model_type"] || "Unknown"
+      device = opts["device"] || self.device.to_s rescue "unknown"
+
+      # Build the inspect string
+      parts = ["#<Candle::LLM"]
+
+      # Add base model or model_id
+      if opts["base_model"]
+        parts << "model=#{opts["base_model"]}"
+      elsif opts["model_id"]
+        parts << "model=#{opts["model_id"]}"
+      elsif respond_to?(:model_id)
+        parts << "model=#{model_id}"
+      end
+
+      # Add GGUF file if present
+      if opts["gguf_file"]
+        parts << "gguf=#{opts["gguf_file"]}"
+      end
+
+      # Add device
+      parts << "device=#{device}"
+
+      # Add model type
+      parts << "type=#{model_type}"
+
+      # Add architecture for GGUF models
+      if opts["architecture"]
+        parts << "arch=#{opts["architecture"]}"
+      end
+
+      parts.join(" ") + ">"
+    end
 
     def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
       begin
@@ -206,7 +245,7 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.
+    def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
       model_str = if gguf_file
         "#{model_id}@#{gguf_file}"
       else
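`LLM.from_pretrained` picks up the same `Device.best` default, and the `model_str` branch shows how a GGUF file is addressed as `repo@file` internally. A sketch with illustrative model ids:

    # Full-precision model on the best available device:
    llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Quantized GGUF weights (repo and file names illustrative):
    llm = Candle::LLM.from_pretrained(
      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
      gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
    )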
@@ -393,5 +432,28 @@ module Candle
       }
       new(defaults.merge(opts))
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::GenerationConfig"]
+
+      # Add key configuration parameters
+      parts << "temp=#{opts["temperature"]}" if opts["temperature"]
+      parts << "max=#{opts["max_length"]}" if opts["max_length"]
+      parts << "top_p=#{opts["top_p"]}" if opts["top_p"]
+      parts << "top_k=#{opts["top_k"]}" if opts["top_k"]
+      parts << "seed=#{opts["seed"]}" if opts["seed"]
+
+      # Add flags
+      flags = []
+      flags << "debug" if opts["debug_tokens"]
+      flags << "constraint" if opts["has_constraint"]
+      flags << "stop_on_match" if opts["stop_on_match"]
+      parts << "flags=[#{flags.join(",")}]" if flags.any?
+
+      parts.join(" ") + ">"
+    end
   end
 end
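`GenerationConfig#inspect` follows the same pattern as the LLM and embedding-model versions, printing only the options that are set. Given the keys it reads, output along these lines can be expected (values illustrative):

    config = Candle::GenerationConfig.balanced
    config.inspect
    # => "#<Candle::GenerationConfig temp=0.7 max=512 top_p=0.9>"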