red-candle 1.1.2 → 1.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +40 -46
- data/Rakefile +79 -88
- data/ext/candle/src/lib.rs +2 -4
- data/ext/candle/src/llm/quantized_gguf.rs +1 -1
- data/ext/candle/src/ruby/device.rs +30 -0
- data/ext/candle/src/ruby/embedding_model.rs +74 -28
- data/ext/candle/src/ruby/llm.rs +96 -1
- data/ext/candle/src/ruby/mod.rs +2 -0
- data/ext/candle/src/{ner.rs → ruby/ner.rs} +47 -15
- data/ext/candle/src/{reranker.rs → ruby/reranker.rs} +24 -2
- data/ext/candle/src/ruby/tensor.rs +101 -26
- data/ext/candle/src/ruby/tokenizer.rs +60 -3
- data/lib/candle/device_utils.rb +3 -15
- data/lib/candle/embedding_model.rb +44 -1
- data/lib/candle/llm.rb +63 -1
- data/lib/candle/ner.rb +45 -22
- data/lib/candle/reranker.rb +20 -1
- data/lib/candle/tensor.rb +15 -0
- data/lib/candle/version.rb +1 -1
- metadata +18 -4
data/ext/candle/src/ruby/tensor.rs
CHANGED

@@ -6,7 +6,7 @@ use crate::ruby::{
     utils::{actual_dim, actual_index},
 };
 use crate::ruby::{DType, Device, Result};
-use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
+use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
 
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
@@ -21,30 +21,108 @@ impl std::ops::Deref for Tensor {
     }
 }
 
+// Helper functions for tensor operations
+impl Tensor {
+    /// Check if device is Metal
+    fn is_metal_device(device: &CoreDevice) -> bool {
+        matches!(device.location(), DeviceLocation::Metal { .. })
+    }
+
+    /// Convert tensor to target dtype, handling Metal limitations
+    fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
+        if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
+            // Move to CPU first to avoid Metal conversion limitations
+            self.0
+                .to_device(&CoreDevice::Cpu)
+                .map_err(wrap_candle_err)?
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        } else {
+            // Direct conversion for CPU or when dtype matches
+            self.0
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        }
+    }
+}
+
 impl Tensor {
     pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
         let dtype = dtype
             .map(|dtype| DType::from_rbobject(dtype))
             .unwrap_or(Ok(DType(CoreDType::F32)))?;
-        let device = device.unwrap_or(Device::…
-        …
+        let device = device.unwrap_or(Device::best()).as_device()?;
+
+        // Create tensor based on target dtype to avoid conversion issues on Metal
+        let tensor = match dtype.0 {
+            CoreDType::F32 => {
+                // Convert to f32 directly to avoid F64->F32 conversion on Metal
+                let array: Vec<f32> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::F64 => {
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::I64 => {
+                // Convert to i64 directly to avoid conversion issues on Metal
+                let array: Vec<i64> = array
+                    .into_iter()
+                    .map(|v| {
+                        // Try integer first, then float
+                        if let Ok(i) = <i64>::try_convert(v) {
+                            Ok(i)
+                        } else if let Ok(f) = magnus::Float::try_convert(v) {
+                            Ok(f.to_f64() as i64)
+                        } else {
+                            Err(magnus::Error::new(
+                                magnus::exception::type_error(),
+                                "Cannot convert to i64"
+                            ))
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            _ => {
+                // For other dtypes, create on CPU first if on Metal, then convert
+                let cpu_device = CoreDevice::Cpu;
+                let use_cpu = Self::is_metal_device(&device);
+                let target_device = if use_cpu { &cpu_device } else { &device };
+
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let tensor = CoreTensor::new(array.as_slice(), target_device)
+                    .map_err(wrap_candle_err)?
+                    .to_dtype(dtype.0)
+                    .map_err(wrap_candle_err)?;
+
+                // Move to target device if needed
+                if use_cpu {
+                    tensor.to_device(&device).map_err(wrap_candle_err)?
+                } else {
+                    tensor
+                }
+            }
+        };
+
+        Ok(Self(tensor))
     }
 
     pub fn values(&self) -> Result<Vec<f64>> {
-        let …
-        …
-            .to_dtype(CoreDType::F64)
-            .map_err(wrap_candle_err)?
+        let tensor = self.safe_to_dtype(CoreDType::F64)?;
+        let values = tensor
             .flatten_all()
             .map_err(wrap_candle_err)?
             .to_vec1()
@@ -92,11 +170,8 @@ impl Tensor {
             }
             _ => {
                 // For other dtypes, convert to F64 first
-                let …
-                …
-                    .map_err(wrap_candle_err)?
-                    .to_vec0()
-                    .map_err(wrap_candle_err)?;
+                let tensor = self.safe_to_dtype(CoreDType::F64)?;
+                let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
                 Ok(val)
             }
         }
@@ -541,7 +616,7 @@ impl Tensor {
     /// Creates a new tensor with random values.
     /// &RETURNS&: Tensor
     pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::…
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -550,7 +625,7 @@ impl Tensor {
     /// Creates a new tensor with random values from a normal distribution.
     /// &RETURNS&: Tensor
     pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::…
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -559,7 +634,7 @@ impl Tensor {
     /// Creates a new tensor filled with ones.
     /// &RETURNS&: Tensor
     pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::…
+        let device = device.unwrap_or(Device::best()).as_device()?;
        Ok(Self(
             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
@@ -567,7 +642,7 @@ impl Tensor {
     /// Creates a new tensor filled with zeros.
     /// &RETURNS&: Tensor
     pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::…
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
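
Net effect for Ruby callers: tensors are created at their requested dtype on the best available device, and cross-dtype reads detour through the CPU only when the tensor lives on Metal. A minimal sketch of the user-facing behavior, assuming the positional `dtype`/`device` arguments mirror the Rust binding above (keyword wrappers in `lib/candle/tensor.rb` may also apply):

```ruby
require "candle"

# Defaults: F32 dtype, Device.best (Metal > CUDA > CPU)
t = Candle::Tensor.new([1.0, 2.0, 3.0])
i = Candle::Tensor.new([1, 2, 3], :i64)   # built directly as i64, no Metal dtype cast

# values reads back as Float via the Metal-safe F64 path internally
t.values   # => [1.0, 2.0, 3.0]
```
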
data/ext/candle/src/ruby/tokenizer.rs
CHANGED

@@ -105,8 +105,8 @@ impl Tokenizer {
         }
 
         let hash = RHash::new();
-        hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
-        hash.aset("tokens", RArray::from_vec(tokens))?;
+        hash.aset(magnus::Symbol::new("ids"), RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+        hash.aset(magnus::Symbol::new("tokens"), RArray::from_vec(tokens))?;
 
         Ok(hash)
     }
@@ -236,9 +236,65 @@ impl Tokenizer {
         Ok(hash)
     }
 
+    /// Get tokenizer options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Get vocab size
+        hash.aset("vocab_size", self.vocab_size(Some(true)))?;
+        hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
+
+        // Get special tokens info
+        let special_tokens = self.get_special_tokens()?;
+        hash.aset("special_tokens", special_tokens)?;
+
+        // Get padding/truncation info if available
+        let inner_tokenizer = self.0.inner();
+
+        // Check if padding is enabled
+        if let Some(_padding) = inner_tokenizer.get_padding() {
+            let padding_info = RHash::new();
+            padding_info.aset("enabled", true)?;
+            // Note: We can't easily extract all padding params from the tokenizers library,
+            // but we can indicate it's enabled
+            hash.aset("padding", padding_info)?;
+        }
+
+        // Check if truncation is enabled
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            let truncation_info = RHash::new();
+            truncation_info.aset("enabled", true)?;
+            truncation_info.aset("max_length", truncation.max_length)?;
+            hash.aset("truncation", truncation_info)?;
+        }
+
+        Ok(hash)
+    }
+
     /// String representation
     pub fn inspect(&self) -> String {
-        …
+        let vocab_size = self.vocab_size(Some(true));
+        let special_tokens = self.get_special_tokens()
+            .ok()
+            .map(|h| h.len())
+            .unwrap_or(0);
+
+        let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
+
+        if special_tokens > 0 {
+            parts.push(format!("special_tokens={}", special_tokens));
+        }
+
+        // Check for padding/truncation
+        let inner_tokenizer = self.0.inner();
+        if inner_tokenizer.get_padding().is_some() {
+            parts.push("padding=enabled".to_string());
+        }
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            parts.push(format!("truncation={}", truncation.max_length));
+        }
+
+        parts.join(" ") + ">"
     }
 }
 
@@ -262,6 +318,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
     tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
     tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
+    tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
     tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
     tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
 
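
Two things to note for Ruby callers: the encode hash is now keyed by symbols (`:ids`, `:tokens`) instead of strings, and the new `options` method exposes the configuration that `inspect` summarizes. A rough usage sketch — `Tokenizer.from_pretrained` and the exact keys/values shown are assumptions based on the gem's existing API, not output captured from this release:

```ruby
tokenizer = Candle::Tokenizer.from_pretrained("bert-base-uncased")

tokenizer.options
# => { "vocab_size" => 30522, "vocab_size_base" => 30522,
#      "special_tokens" => { "[CLS]" => 101, ... },
#      "truncation" => { "enabled" => true, "max_length" => 512 } }

tokenizer.inspect
# => "#<Candle::Tokenizer vocab_size=30522 special_tokens=5 truncation=512>"
```
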
data/lib/candle/device_utils.rb
CHANGED

@@ -1,22 +1,10 @@
 module Candle
   module DeviceUtils
+    # @deprecated Use {Candle::Device.best} instead
     # Get the best available device (Metal > CUDA > CPU)
     def self.best_device
-      …
-      # Try Metal first (for Mac users)
-      Device.metal
-    rescue
-      # :nocov:
-      begin
-        # Try CUDA next (for NVIDIA GPU users)
-        Device.cuda
-      rescue
-        # Fall back to CPU
-        Device.cpu
-      end
-      # :nocov:
-    end
+      warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
+      Device.best
     end
   end
 end
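
The hand-rolled Metal/CUDA/CPU fallback moves into the native `Device.best` (see the new `ruby/device.rs` in the file list), so migration is mechanical:

```ruby
# Before (still works, but prints a deprecation warning):
device = Candle::DeviceUtils.best_device

# After:
device = Candle::Device.best   # Metal > CUDA > CPU
```
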
data/lib/candle/embedding_model.rb
CHANGED

@@ -9,7 +9,36 @@ module Candle
     # Default embedding model type
     DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
 
+    # Load a pre-trained embedding model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @param tokenizer [String, nil] The tokenizer to use (defaults to using the model's tokenizer)
+    # @param model_type [String, nil] The type of embedding model (auto-detected if nil)
+    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
+    # @return [EmbeddingModel] A new EmbeddingModel instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
+      # Auto-detect model type based on model_id if not provided
+      if model_type.nil?
+        model_type = case model_id.downcase
+                     when /jina/
+                       "jina_bert"
+                     when /distilbert/
+                       "distilbert"
+                     when /minilm/
+                       "minilm"
+                     else
+                       "standard_bert"
+                     end
+      end
+
+      # Use model_id as tokenizer if not specified (usually what you want)
+      tokenizer_id = tokenizer || model_id
+
+      _create(model_id, tokenizer_id, device, model_type, embedding_size)
+    end
+
     # Constructor for creating a new EmbeddingModel with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
@@ -17,9 +46,10 @@ module Candle
     # @param embedding_size [Integer, nil] Override for the embedding size (optional)
     def self.new(model_path: DEFAULT_MODEL_PATH,
                  tokenizer_path: DEFAULT_TOKENIZER_PATH,
-                 device: Candle::Device.…
+                 device: Candle::Device.best,
                  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
                  embedding_size: nil)
+      $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
       _create(model_path, tokenizer_path, device, model_type, embedding_size)
     end
     # Returns the embedding for a string using the specified pooling method.
@@ -28,5 +58,18 @@ module Candle
     def embedding(str, pooling_method: "pooled_normalized")
       _embedding(str, pooling_method)
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::EmbeddingModel"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "type=#{opts["model_type"]}" if opts["model_type"]
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
+
+      parts.join(" ") + ">"
+    end
   end
 end
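
`from_pretrained` becomes the canonical constructor, auto-detecting the model type from the model ID and defaulting the tokenizer to the model's own. A sketch of the new entry point (the inspect output shown is approximate):

```ruby
model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en")
model.inspect
# => "#<Candle::EmbeddingModel model=jinaai/jina-embeddings-v2-base-en type=jina_bert device=metal>"

vec = model.embedding("hello world")   # pooled_normalized by default
```
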
data/lib/candle/llm.rb
CHANGED

@@ -189,6 +189,45 @@ module Candle
       prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      # Extract key information
+      model_type = opts["model_type"] || "Unknown"
+      device = opts["device"] || self.device.to_s rescue "unknown"
+
+      # Build the inspect string
+      parts = ["#<Candle::LLM"]
+
+      # Add base model or model_id
+      if opts["base_model"]
+        parts << "model=#{opts["base_model"]}"
+      elsif opts["model_id"]
+        parts << "model=#{opts["model_id"]}"
+      elsif respond_to?(:model_id)
+        parts << "model=#{model_id}"
+      end
+
+      # Add GGUF file if present
+      if opts["gguf_file"]
+        parts << "gguf=#{opts["gguf_file"]}"
+      end
+
+      # Add device
+      parts << "device=#{device}"
+
+      # Add model type
+      parts << "type=#{model_type}"
+
+      # Add architecture for GGUF models
+      if opts["architecture"]
+        parts << "arch=#{opts["architecture"]}"
+      end
+
+      parts.join(" ") + ">"
+    end
 
     def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
       begin
@@ -206,7 +245,7 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.…
+    def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
       model_str = if gguf_file
         "#{model_id}@#{gguf_file}"
       else
@@ -393,5 +432,28 @@ module Candle
       }
       new(defaults.merge(opts))
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::GenerationConfig"]
+
+      # Add key configuration parameters
+      parts << "temp=#{opts["temperature"]}" if opts["temperature"]
+      parts << "max=#{opts["max_length"]}" if opts["max_length"]
+      parts << "top_p=#{opts["top_p"]}" if opts["top_p"]
+      parts << "top_k=#{opts["top_k"]}" if opts["top_k"]
+      parts << "seed=#{opts["seed"]}" if opts["seed"]
+
+      # Add flags
+      flags = []
+      flags << "debug" if opts["debug_tokens"]
+      flags << "constraint" if opts["has_constraint"]
+      flags << "stop_on_match" if opts["stop_on_match"]
+      parts << "flags=[#{flags.join(",")}]" if flags.any?
+
+      parts.join(" ") + ">"
+    end
   end
 end
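
`LLM.from_pretrained` keeps its signature but now defaults to `Device.best`, and the new `inspect` surfaces the GGUF file and architecture when present. A sketch; the model ID, file name, and inspect output are illustrative, not captured from this release:

```ruby
llm = Candle::LLM.from_pretrained(
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",           # illustrative model ID
  gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"   # illustrative file name
)

llm.inspect
# => "#<Candle::LLM model=... gguf=... device=metal type=QuantizedGGUF arch=llama>"
```
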
data/lib/candle/ner.rb
CHANGED

@@ -1,5 +1,8 @@
 # frozen_string_literal: true
 
+# Pattern validation available but not forced
+# require_relative 'pattern_validator' # Uncomment if needed
+
 module Candle
   # Named Entity Recognition (NER) for token classification
   #
@@ -30,10 +33,10 @@ module Candle
       # Load a pre-trained NER model from HuggingFace
       #
       # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
-      # @param device [Device…
+      # @param device [Device] Device to run on (defaults to best available)
       # @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
       # @return [NER] NER instance
-      def from_pretrained(model_id, device: …
+      def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
         new(model_id, device, tokenizer)
       end
 
@@ -112,7 +115,7 @@ module Candle
     # @return [Array<Hash>] Filtered entities of the specified type
     def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
       entities = extract_entities(text, confidence_threshold: confidence_threshold)
-      entities.select { |e| e[…
+      entities.select { |e| e[:label] == entity_type.upcase }
     end
 
     # Analyze text and return both entities and token predictions
@@ -137,12 +140,12 @@ module Candle
       return text if entities.empty?
 
       # Sort by start position (reverse for easier insertion)
-      entities.sort_by! { |e| -e[…
+      entities.sort_by! { |e| -e[:start] }
 
       result = text.dup
       entities.each do |entity|
-        label = "[#{entity[…
-        result.insert(entity[…
+        label = "[#{entity[:label]}:#{entity[:confidence].round(2)}]"
+        result.insert(entity[:end], label)
       end
 
       result
@@ -152,7 +155,19 @@ module Candle
     #
     # @return [String] Model description
     def inspect
-      …
+      opts = options rescue {}
+
+      parts = ["#<Candle::NER"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "labels=#{opts["num_labels"]}" if opts["num_labels"]
+
+      if opts["entity_types"] && !opts["entity_types"].empty?
+        types = opts["entity_types"].sort.join(",")
+        parts << "types=#{types}"
+      end
+
+      parts.join(" ") + ">"
     end
 
     alias to_s inspect
@@ -177,6 +192,14 @@ module Candle
     def recognize(text, tokenizer = nil)
       entities = []
 
+      # Limit text length to prevent ReDoS on very long strings
+      # This is especially important for Ruby < 3.2
+      max_length = 1_000_000 # 1MB of text
+      if text.length > max_length
+        warn "PatternEntityRecognizer: Text truncated from #{text.length} to #{max_length} chars for safety"
+        text = text[0...max_length]
+      end
+
       @patterns.each do |pattern|
         regex = pattern.is_a?(Regexp) ? pattern : Regexp.new(pattern)
 
@@ -186,12 +209,12 @@ module Candle
           match_end = $~.offset(0)[1]
 
           entities << {
-            …
+            text: match_text,
+            label: @entity_type,
+            start: match_start,
+            end: match_end,
+            confidence: 1.0,
+            source: "pattern"
           }
         end
       end
@@ -242,12 +265,12 @@ module Candle
 
           if word_boundary?(prev_char) && word_boundary?(next_char)
             entities << {
-              …
+              text: text[idx, pattern.length],
+              label: @entity_type,
+              start: idx,
+              end: idx + pattern.length,
+              confidence: 1.0,
+              source: "gazetteer"
             }
           end
 
@@ -327,19 +350,19 @@ module Candle
 
     def merge_entities(entities)
       # Sort by start position and confidence (descending)
-      sorted = entities.sort_by { |e| [e[…
+      sorted = entities.sort_by { |e| [e[:start], -e[:confidence]] }
 
       merged = []
       sorted.each do |entity|
         # Check if entity overlaps with any already merged
         overlaps = merged.any? do |existing|
-          entity[…
+          entity[:start] < existing[:end] && entity[:end] > existing[:start]
         end
 
         merged << entity unless overlaps
       end
 
-      merged.sort_by { |e| e[…
+      merged.sort_by { |e| e[:start] }
     end
   end
 end
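
The breaking change here is that entity hashes now use symbol keys throughout (pattern and gazetteer recognizers, `extract_entity_type`, `highlight_entities`, `merge_entities`), so callers indexing with string keys must update. A sketch of the updated access pattern (entity values shown are illustrative):

```ruby
ner = Candle::NER.from_pretrained("dslim/bert-base-NER")
entities = ner.extract_entities("Apple hired Tim Cook in Cupertino.", confidence_threshold: 0.9)

entities.first[:label]        # e.g. "ORG"  (was entities.first["label"])
entities.first[:start]        # character offset where the entity begins
entities.first[:confidence]   # 0.0..1.0
```
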
data/lib/candle/reranker.rb
CHANGED

@@ -3,10 +3,20 @@ module Candle
   # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
   DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
+  # Load a pre-trained reranker model from HuggingFace
+  # @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
+  # @param device [Candle::Device] The device to use for computation (defaults to best available)
+  # @return [Reranker] A new Reranker instance
+  def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
+    _create(model_id, device)
+  end
+
   # Constructor for creating a new Reranker with optional parameters
+  # @deprecated Use {.from_pretrained} instead
   # @param model_path [String, nil] The path to the model on Hugging Face
   # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
-  def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.…
+  def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best)
+    $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
     _create(model_path, device)
   end
 
@@ -20,5 +30,14 @@ module Candle
       { doc_id: doc_id, score: score, text: doc }
     }
   end
+
+  # Improved inspect method
+  def inspect
+    opts = options rescue {}
+    parts = ["#<Candle::Reranker"]
+    parts << "model=#{opts["model_id"] || "unknown"}"
+    parts << "device=#{opts["device"] || "unknown"}"
+    parts.join(" ") + ">"
+  end
 end
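
Same pattern as the other models: `from_pretrained` with a `Device.best` default, plus a deprecation warning on `new`. A sketch; the reranking method name (`rerank`) is an assumption from the gem's public API, while the `{doc_id:, score:, text:}` result shape comes from the diff above:

```ruby
reranker = Candle::Reranker.from_pretrained   # cross-encoder/ms-marco-MiniLM-L-12-v2

results = reranker.rerank("how do I cancel?", ["Refund policy", "Cancellation steps"])
# => [{ doc_id: 1, score: 7.2, text: "Cancellation steps" }, ...]  (scores illustrative)
```
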
data/lib/candle/tensor.rb
CHANGED

@@ -51,6 +51,21 @@ module Candle
       to_f.to_i
     end
 
+    # Improved inspect method showing shape, dtype, and device
+    def inspect
+      shape_str = shape.join("x")
+
+      parts = ["#<Candle::Tensor"]
+      parts << "shape=#{shape_str}"
+      parts << "dtype=#{dtype}"
+      parts << "device=#{device}"
+
+      # Add element count for clarity
+      parts << "elements=#{elem_count}"
+
+      parts.join(" ") + ">"
+    end
+
 
     # Override class methods to support keyword arguments for device
     class << self
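
The Ruby-side `inspect` now reports shape, dtype, device, and element count instead of the default object dump. A sketch of the resulting output (exact dtype/device strings depend on the platform):

```ruby
t = Candle::Tensor.ones([2, 3])
t.inspect
# => "#<Candle::Tensor shape=2x3 dtype=f32 device=cpu elements=6>"
```
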
data/lib/candle/version.rb
CHANGED