red-candle 1.1.2 → 1.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +39 -45
- data/Rakefile +79 -88
- data/ext/candle/src/lib.rs +2 -4
- data/ext/candle/src/llm/quantized_gguf.rs +1 -1
- data/ext/candle/src/ruby/device.rs +30 -0
- data/ext/candle/src/ruby/embedding_model.rs +74 -28
- data/ext/candle/src/ruby/llm.rs +96 -1
- data/ext/candle/src/ruby/mod.rs +2 -0
- data/ext/candle/src/{ner.rs → ruby/ner.rs} +47 -15
- data/ext/candle/src/{reranker.rs → ruby/reranker.rs} +24 -2
- data/ext/candle/src/ruby/tensor.rs +101 -26
- data/ext/candle/src/ruby/tokenizer.rs +60 -3
- data/lib/candle/device_utils.rb +3 -15
- data/lib/candle/embedding_model.rb +44 -1
- data/lib/candle/llm.rb +63 -1
- data/lib/candle/ner.rb +34 -22
- data/lib/candle/reranker.rb +20 -1
- data/lib/candle/tensor.rb +15 -0
- data/lib/candle/version.rb +1 -1
- metadata +18 -4
@@ -6,7 +6,7 @@ use crate::ruby::{
|
|
6
6
|
utils::{actual_dim, actual_index},
|
7
7
|
};
|
8
8
|
use crate::ruby::{DType, Device, Result};
|
9
|
-
use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
|
9
|
+
use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
|
10
10
|
|
11
11
|
#[derive(Clone, Debug)]
|
12
12
|
#[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
|
@@ -21,30 +21,108 @@ impl std::ops::Deref for Tensor {
|
|
21
21
|
}
|
22
22
|
}
|
23
23
|
|
24
|
+
// Helper functions for tensor operations
|
25
|
+
impl Tensor {
|
26
|
+
/// Check if device is Metal
|
27
|
+
fn is_metal_device(device: &CoreDevice) -> bool {
|
28
|
+
matches!(device.location(), DeviceLocation::Metal { .. })
|
29
|
+
}
|
30
|
+
|
31
|
+
/// Convert tensor to target dtype, handling Metal limitations
|
32
|
+
fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
|
33
|
+
if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
|
34
|
+
// Move to CPU first to avoid Metal conversion limitations
|
35
|
+
self.0
|
36
|
+
.to_device(&CoreDevice::Cpu)
|
37
|
+
.map_err(wrap_candle_err)?
|
38
|
+
.to_dtype(target_dtype)
|
39
|
+
.map_err(wrap_candle_err)
|
40
|
+
} else {
|
41
|
+
// Direct conversion for CPU or when dtype matches
|
42
|
+
self.0
|
43
|
+
.to_dtype(target_dtype)
|
44
|
+
.map_err(wrap_candle_err)
|
45
|
+
}
|
46
|
+
}
|
47
|
+
}
|
48
|
+
|
24
49
|
impl Tensor {
|
25
50
|
pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
|
26
51
|
let dtype = dtype
|
27
52
|
.map(|dtype| DType::from_rbobject(dtype))
|
28
53
|
.unwrap_or(Ok(DType(CoreDType::F32)))?;
|
29
|
-
let device = device.unwrap_or(Device::
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
.
|
40
|
-
|
54
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
55
|
+
|
56
|
+
// Create tensor based on target dtype to avoid conversion issues on Metal
|
57
|
+
let tensor = match dtype.0 {
|
58
|
+
CoreDType::F32 => {
|
59
|
+
// Convert to f32 directly to avoid F64->F32 conversion on Metal
|
60
|
+
let array: Vec<f32> = array
|
61
|
+
.into_iter()
|
62
|
+
.map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
|
63
|
+
.collect::<Result<Vec<_>>>()?;
|
64
|
+
let len = array.len();
|
65
|
+
CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
|
66
|
+
}
|
67
|
+
CoreDType::F64 => {
|
68
|
+
let array: Vec<f64> = array
|
69
|
+
.into_iter()
|
70
|
+
.map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
|
71
|
+
.collect::<Result<Vec<_>>>()?;
|
72
|
+
let len = array.len();
|
73
|
+
CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
|
74
|
+
}
|
75
|
+
CoreDType::I64 => {
|
76
|
+
// Convert to i64 directly to avoid conversion issues on Metal
|
77
|
+
let array: Vec<i64> = array
|
78
|
+
.into_iter()
|
79
|
+
.map(|v| {
|
80
|
+
// Try integer first, then float
|
81
|
+
if let Ok(i) = <i64>::try_convert(v) {
|
82
|
+
Ok(i)
|
83
|
+
} else if let Ok(f) = magnus::Float::try_convert(v) {
|
84
|
+
Ok(f.to_f64() as i64)
|
85
|
+
} else {
|
86
|
+
Err(magnus::Error::new(
|
87
|
+
magnus::exception::type_error(),
|
88
|
+
"Cannot convert to i64"
|
89
|
+
))
|
90
|
+
}
|
91
|
+
})
|
92
|
+
.collect::<Result<Vec<_>>>()?;
|
93
|
+
let len = array.len();
|
94
|
+
CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
|
95
|
+
}
|
96
|
+
_ => {
|
97
|
+
// For other dtypes, create on CPU first if on Metal, then convert
|
98
|
+
let cpu_device = CoreDevice::Cpu;
|
99
|
+
let use_cpu = Self::is_metal_device(&device);
|
100
|
+
let target_device = if use_cpu { &cpu_device } else { &device };
|
101
|
+
|
102
|
+
let array: Vec<f64> = array
|
103
|
+
.into_iter()
|
104
|
+
.map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
|
105
|
+
.collect::<Result<Vec<_>>>()?;
|
106
|
+
let tensor = CoreTensor::new(array.as_slice(), target_device)
|
107
|
+
.map_err(wrap_candle_err)?
|
108
|
+
.to_dtype(dtype.0)
|
109
|
+
.map_err(wrap_candle_err)?;
|
110
|
+
|
111
|
+
// Move to target device if needed
|
112
|
+
if use_cpu {
|
113
|
+
tensor.to_device(&device).map_err(wrap_candle_err)?
|
114
|
+
} else {
|
115
|
+
tensor
|
116
|
+
}
|
117
|
+
}
|
118
|
+
};
|
119
|
+
|
120
|
+
Ok(Self(tensor))
|
41
121
|
}
|
42
122
|
|
43
123
|
pub fn values(&self) -> Result<Vec<f64>> {
|
44
|
-
let
|
45
|
-
|
46
|
-
.to_dtype(CoreDType::F64)
|
47
|
-
.map_err(wrap_candle_err)?
|
124
|
+
let tensor = self.safe_to_dtype(CoreDType::F64)?;
|
125
|
+
let values = tensor
|
48
126
|
.flatten_all()
|
49
127
|
.map_err(wrap_candle_err)?
|
50
128
|
.to_vec1()
|
@@ -92,11 +170,8 @@ impl Tensor {
|
|
92
170
|
}
|
93
171
|
_ => {
|
94
172
|
// For other dtypes, convert to F64 first
|
95
|
-
let
|
96
|
-
|
97
|
-
.map_err(wrap_candle_err)?
|
98
|
-
.to_vec0()
|
99
|
-
.map_err(wrap_candle_err)?;
|
173
|
+
let tensor = self.safe_to_dtype(CoreDType::F64)?;
|
174
|
+
let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
|
100
175
|
Ok(val)
|
101
176
|
}
|
102
177
|
}
|
@@ -541,7 +616,7 @@ impl Tensor {
|
|
541
616
|
/// Creates a new tensor with random values.
|
542
617
|
/// &RETURNS&: Tensor
|
543
618
|
pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
|
544
|
-
let device = device.unwrap_or(Device::
|
619
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
545
620
|
Ok(Self(
|
546
621
|
CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
|
547
622
|
))
|
@@ -550,7 +625,7 @@ impl Tensor {
|
|
550
625
|
/// Creates a new tensor with random values from a normal distribution.
|
551
626
|
/// &RETURNS&: Tensor
|
552
627
|
pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
|
553
|
-
let device = device.unwrap_or(Device::
|
628
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
554
629
|
Ok(Self(
|
555
630
|
CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
|
556
631
|
))
|
@@ -559,7 +634,7 @@ impl Tensor {
|
|
559
634
|
/// Creates a new tensor filled with ones.
|
560
635
|
/// &RETURNS&: Tensor
|
561
636
|
pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
|
562
|
-
let device = device.unwrap_or(Device::
|
637
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
563
638
|
Ok(Self(
|
564
639
|
CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
|
565
640
|
))
|
@@ -567,7 +642,7 @@ impl Tensor {
|
|
567
642
|
/// Creates a new tensor filled with zeros.
|
568
643
|
/// &RETURNS&: Tensor
|
569
644
|
pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
|
570
|
-
let device = device.unwrap_or(Device::
|
645
|
+
let device = device.unwrap_or(Device::best()).as_device()?;
|
571
646
|
Ok(Self(
|
572
647
|
CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
|
573
648
|
))
|
@@ -105,8 +105,8 @@ impl Tokenizer {
|
|
105
105
|
}
|
106
106
|
|
107
107
|
let hash = RHash::new();
|
108
|
-
hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
|
109
|
-
hash.aset("tokens", RArray::from_vec(tokens))?;
|
108
|
+
hash.aset(magnus::Symbol::new("ids"), RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
|
109
|
+
hash.aset(magnus::Symbol::new("tokens"), RArray::from_vec(tokens))?;
|
110
110
|
|
111
111
|
Ok(hash)
|
112
112
|
}
|
@@ -236,9 +236,65 @@ impl Tokenizer {
|
|
236
236
|
Ok(hash)
|
237
237
|
}
|
238
238
|
|
239
|
+
/// Get tokenizer options as a hash
|
240
|
+
pub fn options(&self) -> Result<RHash> {
|
241
|
+
let hash = RHash::new();
|
242
|
+
|
243
|
+
// Get vocab size
|
244
|
+
hash.aset("vocab_size", self.vocab_size(Some(true)))?;
|
245
|
+
hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
|
246
|
+
|
247
|
+
// Get special tokens info
|
248
|
+
let special_tokens = self.get_special_tokens()?;
|
249
|
+
hash.aset("special_tokens", special_tokens)?;
|
250
|
+
|
251
|
+
// Get padding/truncation info if available
|
252
|
+
let inner_tokenizer = self.0.inner();
|
253
|
+
|
254
|
+
// Check if padding is enabled
|
255
|
+
if let Some(_padding) = inner_tokenizer.get_padding() {
|
256
|
+
let padding_info = RHash::new();
|
257
|
+
padding_info.aset("enabled", true)?;
|
258
|
+
// Note: We can't easily extract all padding params from the tokenizers library
|
259
|
+
// but we can indicate it's enabled
|
260
|
+
hash.aset("padding", padding_info)?;
|
261
|
+
}
|
262
|
+
|
263
|
+
// Check if truncation is enabled
|
264
|
+
if let Some(truncation) = inner_tokenizer.get_truncation() {
|
265
|
+
let truncation_info = RHash::new();
|
266
|
+
truncation_info.aset("enabled", true)?;
|
267
|
+
truncation_info.aset("max_length", truncation.max_length)?;
|
268
|
+
hash.aset("truncation", truncation_info)?;
|
269
|
+
}
|
270
|
+
|
271
|
+
Ok(hash)
|
272
|
+
}
|
273
|
+
|
239
274
|
/// String representation
|
240
275
|
pub fn inspect(&self) -> String {
|
241
|
-
|
276
|
+
let vocab_size = self.vocab_size(Some(true));
|
277
|
+
let special_tokens = self.get_special_tokens()
|
278
|
+
.ok()
|
279
|
+
.map(|h| h.len())
|
280
|
+
.unwrap_or(0);
|
281
|
+
|
282
|
+
let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
|
283
|
+
|
284
|
+
if special_tokens > 0 {
|
285
|
+
parts.push(format!("special_tokens={}", special_tokens));
|
286
|
+
}
|
287
|
+
|
288
|
+
// Check for padding/truncation
|
289
|
+
let inner_tokenizer = self.0.inner();
|
290
|
+
if inner_tokenizer.get_padding().is_some() {
|
291
|
+
parts.push("padding=enabled".to_string());
|
292
|
+
}
|
293
|
+
if let Some(truncation) = inner_tokenizer.get_truncation() {
|
294
|
+
parts.push(format!("truncation={}", truncation.max_length));
|
295
|
+
}
|
296
|
+
|
297
|
+
parts.join(" ") + ">"
|
242
298
|
}
|
243
299
|
}
|
244
300
|
|
@@ -262,6 +318,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
|
|
262
318
|
tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
|
263
319
|
tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
|
264
320
|
tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
|
321
|
+
tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
|
265
322
|
tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
|
266
323
|
tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
|
267
324
|
|
data/lib/candle/device_utils.rb
CHANGED
@@ -1,22 +1,10 @@
|
|
1
1
|
module Candle
|
2
2
|
module DeviceUtils
|
3
|
+
# @deprecated Use {Candle::Device.best} instead
|
3
4
|
# Get the best available device (Metal > CUDA > CPU)
|
4
5
|
def self.best_device
|
5
|
-
|
6
|
-
|
7
|
-
# Try Metal first (for Mac users)
|
8
|
-
Device.metal
|
9
|
-
rescue
|
10
|
-
# :nocov:
|
11
|
-
begin
|
12
|
-
# Try CUDA next (for NVIDIA GPU users)
|
13
|
-
Device.cuda
|
14
|
-
rescue
|
15
|
-
# Fall back to CPU
|
16
|
-
Device.cpu
|
17
|
-
end
|
18
|
-
# :nocov:
|
19
|
-
end
|
6
|
+
warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
|
7
|
+
Device.best
|
20
8
|
end
|
21
9
|
end
|
22
10
|
end
|
@@ -9,7 +9,36 @@ module Candle
|
|
9
9
|
# Default embedding model type
|
10
10
|
DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
|
11
11
|
|
12
|
+
# Load a pre-trained embedding model from HuggingFace
|
13
|
+
# @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
|
14
|
+
# @param device [Candle::Device] The device to use for computation (defaults to best available)
|
15
|
+
# @param tokenizer [String, nil] The tokenizer to use (defaults to using the model's tokenizer)
|
16
|
+
# @param model_type [String, nil] The type of embedding model (auto-detected if nil)
|
17
|
+
# @param embedding_size [Integer, nil] Override for the embedding size (optional)
|
18
|
+
# @return [EmbeddingModel] A new EmbeddingModel instance
|
19
|
+
def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
|
20
|
+
# Auto-detect model type based on model_id if not provided
|
21
|
+
if model_type.nil?
|
22
|
+
model_type = case model_id.downcase
|
23
|
+
when /jina/
|
24
|
+
"jina_bert"
|
25
|
+
when /distilbert/
|
26
|
+
"distilbert"
|
27
|
+
when /minilm/
|
28
|
+
"minilm"
|
29
|
+
else
|
30
|
+
"standard_bert"
|
31
|
+
end
|
32
|
+
end
|
33
|
+
|
34
|
+
# Use model_id as tokenizer if not specified (usually what you want)
|
35
|
+
tokenizer_id = tokenizer || model_id
|
36
|
+
|
37
|
+
_create(model_id, tokenizer_id, device, model_type, embedding_size)
|
38
|
+
end
|
39
|
+
|
12
40
|
# Constructor for creating a new EmbeddingModel with optional parameters
|
41
|
+
# @deprecated Use {.from_pretrained} instead
|
13
42
|
# @param model_path [String, nil] The path to the model on Hugging Face
|
14
43
|
# @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
|
15
44
|
# @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
|
@@ -17,9 +46,10 @@ module Candle
|
|
17
46
|
# @param embedding_size [Integer, nil] Override for the embedding size (optional)
|
18
47
|
def self.new(model_path: DEFAULT_MODEL_PATH,
|
19
48
|
tokenizer_path: DEFAULT_TOKENIZER_PATH,
|
20
|
-
device: Candle::Device.
|
49
|
+
device: Candle::Device.best,
|
21
50
|
model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
|
22
51
|
embedding_size: nil)
|
52
|
+
$stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
|
23
53
|
_create(model_path, tokenizer_path, device, model_type, embedding_size)
|
24
54
|
end
|
25
55
|
# Returns the embedding for a string using the specified pooling method.
|
@@ -28,5 +58,18 @@ module Candle
|
|
28
58
|
def embedding(str, pooling_method: "pooled_normalized")
|
29
59
|
_embedding(str, pooling_method)
|
30
60
|
end
|
61
|
+
|
62
|
+
# Improved inspect method
|
63
|
+
def inspect
|
64
|
+
opts = options rescue {}
|
65
|
+
|
66
|
+
parts = ["#<Candle::EmbeddingModel"]
|
67
|
+
parts << "model=#{opts["model_id"] || "unknown"}"
|
68
|
+
parts << "type=#{opts["model_type"]}" if opts["model_type"]
|
69
|
+
parts << "device=#{opts["device"] || "unknown"}"
|
70
|
+
parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
|
71
|
+
|
72
|
+
parts.join(" ") + ">"
|
73
|
+
end
|
31
74
|
end
|
32
75
|
end
|
data/lib/candle/llm.rb
CHANGED
@@ -189,6 +189,45 @@ module Candle
|
|
189
189
|
prompt = apply_chat_template(messages)
|
190
190
|
generate_stream(prompt, **options, &block)
|
191
191
|
end
|
192
|
+
|
193
|
+
# Inspect method for debugging and exploration
|
194
|
+
def inspect
|
195
|
+
opts = options rescue {}
|
196
|
+
|
197
|
+
# Extract key information
|
198
|
+
model_type = opts["model_type"] || "Unknown"
|
199
|
+
device = opts["device"] || self.device.to_s rescue "unknown"
|
200
|
+
|
201
|
+
# Build the inspect string
|
202
|
+
parts = ["#<Candle::LLM"]
|
203
|
+
|
204
|
+
# Add base model or model_id
|
205
|
+
if opts["base_model"]
|
206
|
+
parts << "model=#{opts["base_model"]}"
|
207
|
+
elsif opts["model_id"]
|
208
|
+
parts << "model=#{opts["model_id"]}"
|
209
|
+
elsif respond_to?(:model_id)
|
210
|
+
parts << "model=#{model_id}"
|
211
|
+
end
|
212
|
+
|
213
|
+
# Add GGUF file if present
|
214
|
+
if opts["gguf_file"]
|
215
|
+
parts << "gguf=#{opts["gguf_file"]}"
|
216
|
+
end
|
217
|
+
|
218
|
+
# Add device
|
219
|
+
parts << "device=#{device}"
|
220
|
+
|
221
|
+
# Add model type
|
222
|
+
parts << "type=#{model_type}"
|
223
|
+
|
224
|
+
# Add architecture for GGUF models
|
225
|
+
if opts["architecture"]
|
226
|
+
parts << "arch=#{opts["architecture"]}"
|
227
|
+
end
|
228
|
+
|
229
|
+
parts.join(" ") + ">"
|
230
|
+
end
|
192
231
|
|
193
232
|
def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
|
194
233
|
begin
|
@@ -206,7 +245,7 @@ module Candle
|
|
206
245
|
end
|
207
246
|
end
|
208
247
|
|
209
|
-
def self.from_pretrained(model_id, device: Candle::Device.
|
248
|
+
def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
|
210
249
|
model_str = if gguf_file
|
211
250
|
"#{model_id}@#{gguf_file}"
|
212
251
|
else
|
@@ -393,5 +432,28 @@ module Candle
|
|
393
432
|
}
|
394
433
|
new(defaults.merge(opts))
|
395
434
|
end
|
435
|
+
|
436
|
+
# Inspect method for debugging and exploration
|
437
|
+
def inspect
|
438
|
+
opts = options rescue {}
|
439
|
+
|
440
|
+
parts = ["#<Candle::GenerationConfig"]
|
441
|
+
|
442
|
+
# Add key configuration parameters
|
443
|
+
parts << "temp=#{opts["temperature"]}" if opts["temperature"]
|
444
|
+
parts << "max=#{opts["max_length"]}" if opts["max_length"]
|
445
|
+
parts << "top_p=#{opts["top_p"]}" if opts["top_p"]
|
446
|
+
parts << "top_k=#{opts["top_k"]}" if opts["top_k"]
|
447
|
+
parts << "seed=#{opts["seed"]}" if opts["seed"]
|
448
|
+
|
449
|
+
# Add flags
|
450
|
+
flags = []
|
451
|
+
flags << "debug" if opts["debug_tokens"]
|
452
|
+
flags << "constraint" if opts["has_constraint"]
|
453
|
+
flags << "stop_on_match" if opts["stop_on_match"]
|
454
|
+
parts << "flags=[#{flags.join(",")}]" if flags.any?
|
455
|
+
|
456
|
+
parts.join(" ") + ">"
|
457
|
+
end
|
396
458
|
end
|
397
459
|
end
|
data/lib/candle/ner.rb
CHANGED
@@ -30,10 +30,10 @@ module Candle
|
|
30
30
|
# Load a pre-trained NER model from HuggingFace
|
31
31
|
#
|
32
32
|
# @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
|
33
|
-
# @param device [Device
|
33
|
+
# @param device [Device] Device to run on (defaults to best available)
|
34
34
|
# @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
|
35
35
|
# @return [NER] NER instance
|
36
|
-
def from_pretrained(model_id, device:
|
36
|
+
def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
|
37
37
|
new(model_id, device, tokenizer)
|
38
38
|
end
|
39
39
|
|
@@ -112,7 +112,7 @@ module Candle
|
|
112
112
|
# @return [Array<Hash>] Filtered entities of the specified type
|
113
113
|
def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
|
114
114
|
entities = extract_entities(text, confidence_threshold: confidence_threshold)
|
115
|
-
entities.select { |e| e[
|
115
|
+
entities.select { |e| e[:label] == entity_type.upcase }
|
116
116
|
end
|
117
117
|
|
118
118
|
# Analyze text and return both entities and token predictions
|
@@ -137,12 +137,12 @@ module Candle
|
|
137
137
|
return text if entities.empty?
|
138
138
|
|
139
139
|
# Sort by start position (reverse for easier insertion)
|
140
|
-
entities.sort_by! { |e| -e[
|
140
|
+
entities.sort_by! { |e| -e[:start] }
|
141
141
|
|
142
142
|
result = text.dup
|
143
143
|
entities.each do |entity|
|
144
|
-
label = "[#{entity[
|
145
|
-
result.insert(entity[
|
144
|
+
label = "[#{entity[:label]}:#{entity[:confidence].round(2)}]"
|
145
|
+
result.insert(entity[:end], label)
|
146
146
|
end
|
147
147
|
|
148
148
|
result
|
@@ -152,7 +152,19 @@ module Candle
|
|
152
152
|
#
|
153
153
|
# @return [String] Model description
|
154
154
|
def inspect
|
155
|
-
|
155
|
+
opts = options rescue {}
|
156
|
+
|
157
|
+
parts = ["#<Candle::NER"]
|
158
|
+
parts << "model=#{opts["model_id"] || "unknown"}"
|
159
|
+
parts << "device=#{opts["device"] || "unknown"}"
|
160
|
+
parts << "labels=#{opts["num_labels"]}" if opts["num_labels"]
|
161
|
+
|
162
|
+
if opts["entity_types"] && !opts["entity_types"].empty?
|
163
|
+
types = opts["entity_types"].sort.join(",")
|
164
|
+
parts << "types=#{types}"
|
165
|
+
end
|
166
|
+
|
167
|
+
parts.join(" ") + ">"
|
156
168
|
end
|
157
169
|
|
158
170
|
alias to_s inspect
|
@@ -186,12 +198,12 @@ module Candle
|
|
186
198
|
match_end = $~.offset(0)[1]
|
187
199
|
|
188
200
|
entities << {
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
201
|
+
text: match_text,
|
202
|
+
label: @entity_type,
|
203
|
+
start: match_start,
|
204
|
+
end: match_end,
|
205
|
+
confidence: 1.0,
|
206
|
+
source: "pattern"
|
195
207
|
}
|
196
208
|
end
|
197
209
|
end
|
@@ -242,12 +254,12 @@ module Candle
|
|
242
254
|
|
243
255
|
if word_boundary?(prev_char) && word_boundary?(next_char)
|
244
256
|
entities << {
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
257
|
+
text: text[idx, pattern.length],
|
258
|
+
label: @entity_type,
|
259
|
+
start: idx,
|
260
|
+
end: idx + pattern.length,
|
261
|
+
confidence: 1.0,
|
262
|
+
source: "gazetteer"
|
251
263
|
}
|
252
264
|
end
|
253
265
|
|
@@ -327,19 +339,19 @@ module Candle
|
|
327
339
|
|
328
340
|
def merge_entities(entities)
|
329
341
|
# Sort by start position and confidence (descending)
|
330
|
-
sorted = entities.sort_by { |e| [e[
|
342
|
+
sorted = entities.sort_by { |e| [e[:start], -e[:confidence]] }
|
331
343
|
|
332
344
|
merged = []
|
333
345
|
sorted.each do |entity|
|
334
346
|
# Check if entity overlaps with any already merged
|
335
347
|
overlaps = merged.any? do |existing|
|
336
|
-
entity[
|
348
|
+
entity[:start] < existing[:end] && entity[:end] > existing[:start]
|
337
349
|
end
|
338
350
|
|
339
351
|
merged << entity unless overlaps
|
340
352
|
end
|
341
353
|
|
342
|
-
merged.sort_by { |e| e[
|
354
|
+
merged.sort_by { |e| e[:start] }
|
343
355
|
end
|
344
356
|
end
|
345
357
|
end
|
data/lib/candle/reranker.rb
CHANGED
@@ -3,10 +3,20 @@ module Candle
|
|
3
3
|
# Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
|
4
4
|
DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
|
5
5
|
|
6
|
+
# Load a pre-trained reranker model from HuggingFace
|
7
|
+
# @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
|
8
|
+
# @param device [Candle::Device] The device to use for computation (defaults to best available)
|
9
|
+
# @return [Reranker] A new Reranker instance
|
10
|
+
def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
|
11
|
+
_create(model_id, device)
|
12
|
+
end
|
13
|
+
|
6
14
|
# Constructor for creating a new Reranker with optional parameters
|
15
|
+
# @deprecated Use {.from_pretrained} instead
|
7
16
|
# @param model_path [String, nil] The path to the model on Hugging Face
|
8
17
|
# @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
|
9
|
-
def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.
|
18
|
+
def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best)
|
19
|
+
$stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
|
10
20
|
_create(model_path, device)
|
11
21
|
end
|
12
22
|
|
@@ -20,5 +30,14 @@ module Candle
|
|
20
30
|
{ doc_id: doc_id, score: score, text: doc }
|
21
31
|
}
|
22
32
|
end
|
33
|
+
|
34
|
+
# Improved inspect method
|
35
|
+
def inspect
|
36
|
+
opts = options rescue {}
|
37
|
+
parts = ["#<Candle::Reranker"]
|
38
|
+
parts << "model=#{opts["model_id"] || "unknown"}"
|
39
|
+
parts << "device=#{opts["device"] || "unknown"}"
|
40
|
+
parts.join(" ") + ">"
|
41
|
+
end
|
23
42
|
end
|
24
43
|
end
|
data/lib/candle/tensor.rb
CHANGED
@@ -51,6 +51,21 @@ module Candle
|
|
51
51
|
to_f.to_i
|
52
52
|
end
|
53
53
|
|
54
|
+
# Improved inspect method showing shape, dtype, and device
|
55
|
+
def inspect
|
56
|
+
shape_str = shape.join("x")
|
57
|
+
|
58
|
+
parts = ["#<Candle::Tensor"]
|
59
|
+
parts << "shape=#{shape_str}"
|
60
|
+
parts << "dtype=#{dtype}"
|
61
|
+
parts << "device=#{device}"
|
62
|
+
|
63
|
+
# Add element count for clarity
|
64
|
+
parts << "elements=#{elem_count}"
|
65
|
+
|
66
|
+
parts.join(" ") + ">"
|
67
|
+
end
|
68
|
+
|
54
69
|
|
55
70
|
# Override class methods to support keyword arguments for device
|
56
71
|
class << self
|
data/lib/candle/version.rb
CHANGED