red-candle 1.1.2 → 1.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,7 +6,7 @@ use crate::ruby::{
     utils::{actual_dim, actual_index},
 };
 use crate::ruby::{DType, Device, Result};
-use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
+use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
 
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
@@ -21,30 +21,108 @@ impl std::ops::Deref for Tensor {
     }
 }
 
+// Helper functions for tensor operations
+impl Tensor {
+    /// Check if device is Metal
+    fn is_metal_device(device: &CoreDevice) -> bool {
+        matches!(device.location(), DeviceLocation::Metal { .. })
+    }
+
+    /// Convert tensor to target dtype, handling Metal limitations
+    fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
+        if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
+            // Move to CPU first to avoid Metal conversion limitations
+            self.0
+                .to_device(&CoreDevice::Cpu)
+                .map_err(wrap_candle_err)?
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        } else {
+            // Direct conversion for CPU or when dtype matches
+            self.0
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        }
+    }
+}
+
 impl Tensor {
     pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
         let dtype = dtype
             .map(|dtype| DType::from_rbobject(dtype))
             .unwrap_or(Ok(DType(CoreDType::F32)))?;
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
-        // FIXME: Do not use `to_f64` here.
-        let array = array
-            .into_iter()
-            .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self(
-            CoreTensor::new(array.as_slice(), &device)
-                .map_err(wrap_candle_err)?
-                .to_dtype(dtype.0)
-                .map_err(wrap_candle_err)?,
-        ))
+        let device = device.unwrap_or(Device::best()).as_device()?;
+
+        // Create tensor based on target dtype to avoid conversion issues on Metal
+        let tensor = match dtype.0 {
+            CoreDType::F32 => {
+                // Convert to f32 directly to avoid F64->F32 conversion on Metal
+                let array: Vec<f32> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::F64 => {
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::I64 => {
+                // Convert to i64 directly to avoid conversion issues on Metal
+                let array: Vec<i64> = array
+                    .into_iter()
+                    .map(|v| {
+                        // Try integer first, then float
+                        if let Ok(i) = <i64>::try_convert(v) {
+                            Ok(i)
+                        } else if let Ok(f) = magnus::Float::try_convert(v) {
+                            Ok(f.to_f64() as i64)
+                        } else {
+                            Err(magnus::Error::new(
+                                magnus::exception::type_error(),
+                                "Cannot convert to i64"
+                            ))
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            _ => {
+                // For other dtypes, create on CPU first if on Metal, then convert
+                let cpu_device = CoreDevice::Cpu;
+                let use_cpu = Self::is_metal_device(&device);
+                let target_device = if use_cpu { &cpu_device } else { &device };
+
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let tensor = CoreTensor::new(array.as_slice(), target_device)
+                    .map_err(wrap_candle_err)?
+                    .to_dtype(dtype.0)
+                    .map_err(wrap_candle_err)?;
+
+                // Move to target device if needed
+                if use_cpu {
+                    tensor.to_device(&device).map_err(wrap_candle_err)?
+                } else {
+                    tensor
+                }
+            }
+        };
+
+        Ok(Self(tensor))
     }
 
     pub fn values(&self) -> Result<Vec<f64>> {
-        let values = self
-            .0
-            .to_dtype(CoreDType::F64)
-            .map_err(wrap_candle_err)?
+        let tensor = self.safe_to_dtype(CoreDType::F64)?;
+        let values = tensor
             .flatten_all()
             .map_err(wrap_candle_err)?
             .to_vec1()
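
The rewritten constructor materializes the data in the requested dtype before the tensor ever touches the GPU, sidestepping Metal's limited dtype-conversion kernels, and the default device switches from CPU to the best available one. A minimal Ruby sketch of the resulting behavior (positional arguments follow the Rust signature above; session values are illustrative):

    require "candle"

    # F32 path: values become f32 at the Ruby/Rust boundary,
    # so no F64->F32 conversion runs on a Metal device.
    t = Candle::Tensor.new([1.0, 2.0, 3.0], :f32)

    # I64 path: integers are accepted directly, floats are truncated.
    ids = Candle::Tensor.new([1, 2, 3], :i64)

    # values() now routes through safe_to_dtype, hopping via CPU only
    # when a Metal tensor actually needs a dtype change.
    t.values # => [1.0, 2.0, 3.0]
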
@@ -92,11 +170,8 @@ impl Tensor {
             }
             _ => {
                 // For other dtypes, convert to F64 first
-                let val: f64 = self.0
-                    .to_dtype(CoreDType::F64)
-                    .map_err(wrap_candle_err)?
-                    .to_vec0()
-                    .map_err(wrap_candle_err)?;
+                let tensor = self.safe_to_dtype(CoreDType::F64)?;
+                let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
                 Ok(val)
             }
         }
@@ -541,7 +616,7 @@ impl Tensor {
     /// Creates a new tensor with random values.
     /// &RETURNS&: Tensor
     pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -550,7 +625,7 @@ impl Tensor {
     /// Creates a new tensor with random values from a normal distribution.
     /// &RETURNS&: Tensor
     pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -559,7 +634,7 @@ impl Tensor {
     /// Creates a new tensor filled with ones.
     /// &RETURNS&: Tensor
    pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
@@ -567,7 +642,7 @@ impl Tensor {
     /// Creates a new tensor filled with zeros.
     /// &RETURNS&: Tensor
     pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
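
All four factory methods now default to `Device.best()` rather than the CPU, so freshly created tensors land on Metal or CUDA when present. An explicit device argument still overrides the default:

    a = Candle::Tensor.rand([2, 3])                       # best available device
    b = Candle::Tensor.randn([2, 3])
    ones = Candle::Tensor.ones([4])
    zeros = Candle::Tensor.zeros([4], Candle::Device.cpu) # force CPU explicitly
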
@@ -105,8 +105,8 @@ impl Tokenizer {
         }
 
         let hash = RHash::new();
-        hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
-        hash.aset("tokens", RArray::from_vec(tokens))?;
+        hash.aset(magnus::Symbol::new("ids"), RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+        hash.aset(magnus::Symbol::new("tokens"), RArray::from_vec(tokens))?;
 
         Ok(hash)
     }
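
The encode result hash switches from string keys to symbol keys, a breaking change for existing callers. Assuming a tokenizer obtained elsewhere in the API (the encoding method's name is not visible in this hunk):

    result = tokenizer.encode("hello world") # method name assumed for illustration

    result["ids"]   # 1.1.2 key style (now nil)
    result[:ids]    # 1.2.1 => array of token ids
    result[:tokens] # 1.2.1 => array of token strings
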
@@ -236,9 +236,65 @@ impl Tokenizer {
         Ok(hash)
     }
 
+    /// Get tokenizer options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Get vocab size
+        hash.aset("vocab_size", self.vocab_size(Some(true)))?;
+        hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
+
+        // Get special tokens info
+        let special_tokens = self.get_special_tokens()?;
+        hash.aset("special_tokens", special_tokens)?;
+
+        // Get padding/truncation info if available
+        let inner_tokenizer = self.0.inner();
+
+        // Check if padding is enabled
+        if let Some(_padding) = inner_tokenizer.get_padding() {
+            let padding_info = RHash::new();
+            padding_info.aset("enabled", true)?;
+            // Note: We can't easily extract all padding params from the tokenizers library
+            // but we can indicate it's enabled
+            hash.aset("padding", padding_info)?;
+        }
+
+        // Check if truncation is enabled
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            let truncation_info = RHash::new();
+            truncation_info.aset("enabled", true)?;
+            truncation_info.aset("max_length", truncation.max_length)?;
+            hash.aset("truncation", truncation_info)?;
+        }
+
+        Ok(hash)
+    }
+
     /// String representation
     pub fn inspect(&self) -> String {
-        format!("#<Candle::Tokenizer vocab_size={}>", self.vocab_size(Some(true)))
+        let vocab_size = self.vocab_size(Some(true));
+        let special_tokens = self.get_special_tokens()
+            .ok()
+            .map(|h| h.len())
+            .unwrap_or(0);
+
+        let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
+
+        if special_tokens > 0 {
+            parts.push(format!("special_tokens={}", special_tokens));
+        }
+
+        // Check for padding/truncation
+        let inner_tokenizer = self.0.inner();
+        if inner_tokenizer.get_padding().is_some() {
+            parts.push("padding=enabled".to_string());
+        }
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            parts.push(format!("truncation={}", truncation.max_length));
+        }
+
+        parts.join(" ") + ">"
     }
 }
 
@@ -262,6 +318,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
     tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
     tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
+    tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
     tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
     tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
 
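
The new `Tokenizer#options` aggregates vocab sizes, special tokens, and padding/truncation state into one hash, and `inspect` now summarizes the same information; both are registered on the Ruby class in `init`. Illustrative output (keys mirror the Rust code above; values depend on the tokenizer):

    tok.options
    # => { "vocab_size" => 30522, "vocab_size_base" => 30522,
    #      "special_tokens" => {...},
    #      "truncation" => { "enabled" => true, "max_length" => 512 } }

    tok.inspect
    # => "#<Candle::Tokenizer vocab_size=30522 special_tokens=5 truncation=512>"
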
@@ -1,22 +1,10 @@
 module Candle
   module DeviceUtils
+    # @deprecated Use {Candle::Device.best} instead
     # Get the best available device (Metal > CUDA > CPU)
     def self.best_device
-      # Try devices in order of preference
-      begin
-        # Try Metal first (for Mac users)
-        Device.metal
-      rescue
-        # :nocov:
-        begin
-          # Try CUDA next (for NVIDIA GPU users)
-          Device.cuda
-        rescue
-          # Fall back to CPU
-          Device.cpu
-        end
-        # :nocov:
-      end
+      warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
+      Device.best
     end
   end
 end
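
Device probing moves out of Ruby: `DeviceUtils.best_device` is now a deprecated shim over `Device.best`, which applies the same Metal > CUDA > CPU preference natively. Migration is one line:

    device = Candle::DeviceUtils.best_device # 1.2.1: warns, then delegates
    device = Candle::Device.best             # preferred from 1.2.1 on
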
@@ -9,7 +9,36 @@ module Candle
     # Default embedding model type
     DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
 
+    # Load a pre-trained embedding model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @param tokenizer [String, nil] The tokenizer to use (defaults to using the model's tokenizer)
+    # @param model_type [String, nil] The type of embedding model (auto-detected if nil)
+    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
+    # @return [EmbeddingModel] A new EmbeddingModel instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
+      # Auto-detect model type based on model_id if not provided
+      if model_type.nil?
+        model_type = case model_id.downcase
+                     when /jina/
+                       "jina_bert"
+                     when /distilbert/
+                       "distilbert"
+                     when /minilm/
+                       "minilm"
+                     else
+                       "standard_bert"
+                     end
+      end
+
+      # Use model_id as tokenizer if not specified (usually what you want)
+      tokenizer_id = tokenizer || model_id
+
+      _create(model_id, tokenizer_id, device, model_type, embedding_size)
+    end
+
     # Constructor for creating a new EmbeddingModel with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
@@ -17,9 +46,10 @@ module Candle
     # @param embedding_size [Integer, nil] Override for the embedding size (optional)
     def self.new(model_path: DEFAULT_MODEL_PATH,
                  tokenizer_path: DEFAULT_TOKENIZER_PATH,
-                 device: Candle::Device.cpu,
+                 device: Candle::Device.best,
                  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
                  embedding_size: nil)
+      $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
       _create(model_path, tokenizer_path, device, model_type, embedding_size)
     end
     # Returns the embedding for a string using the specified pooling method.
@@ -28,5 +58,18 @@ module Candle
     def embedding(str, pooling_method: "pooled_normalized")
       _embedding(str, pooling_method)
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::EmbeddingModel"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "type=#{opts["model_type"]}" if opts["model_type"]
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
+
+      parts.join(" ") + ">"
+    end
   end
 end
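
`EmbeddingModel.from_pretrained` becomes the supported constructor, auto-detecting the model type from the id and defaulting to the best device, while `EmbeddingModel.new` keeps working behind a deprecation warning. For example:

    # "jina" in the id selects the "jina_bert" model type automatically.
    model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en")

    vec = model.embedding("Candle tensors from Ruby")
    model.inspect # => "#<Candle::EmbeddingModel model=jinaai/... type=jina_bert ...>" (illustrative)
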
data/lib/candle/llm.rb CHANGED
@@ -189,6 +189,45 @@ module Candle
       prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      # Extract key information
+      model_type = opts["model_type"] || "Unknown"
+      device = opts["device"] || self.device.to_s rescue "unknown"
+
+      # Build the inspect string
+      parts = ["#<Candle::LLM"]
+
+      # Add base model or model_id
+      if opts["base_model"]
+        parts << "model=#{opts["base_model"]}"
+      elsif opts["model_id"]
+        parts << "model=#{opts["model_id"]}"
+      elsif respond_to?(:model_id)
+        parts << "model=#{model_id}"
+      end
+
+      # Add GGUF file if present
+      if opts["gguf_file"]
+        parts << "gguf=#{opts["gguf_file"]}"
+      end
+
+      # Add device
+      parts << "device=#{device}"
+
+      # Add model type
+      parts << "type=#{model_type}"
+
+      # Add architecture for GGUF models
+      if opts["architecture"]
+        parts << "arch=#{opts["architecture"]}"
+      end
+
+      parts.join(" ") + ">"
+    end
 
     def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
       begin
@@ -206,7 +245,7 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+    def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
       model_str = if gguf_file
                     "#{model_id}@#{gguf_file}"
                   else
@@ -393,5 +432,28 @@ module Candle
       }
       new(defaults.merge(opts))
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::GenerationConfig"]
+
+      # Add key configuration parameters
+      parts << "temp=#{opts["temperature"]}" if opts["temperature"]
+      parts << "max=#{opts["max_length"]}" if opts["max_length"]
+      parts << "top_p=#{opts["top_p"]}" if opts["top_p"]
+      parts << "top_k=#{opts["top_k"]}" if opts["top_k"]
+      parts << "seed=#{opts["seed"]}" if opts["seed"]
+
+      # Add flags
+      flags = []
+      flags << "debug" if opts["debug_tokens"]
+      flags << "constraint" if opts["has_constraint"]
+      flags << "stop_on_match" if opts["stop_on_match"]
+      parts << "flags=[#{flags.join(",")}]" if flags.any?
+
+      parts.join(" ") + ">"
+    end
   end
 end
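
`LLM.from_pretrained` picks up the same `Device.best` default, and both `LLM` and `GenerationConfig` gain `inspect` methods built from their `options` hash. A sketch (model id and output values illustrative):

    llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    llm.inspect
    # => "#<Candle::LLM model=TinyLlama/TinyLlama-1.1B-Chat-v1.0 device=metal type=llama>"

    config = Candle::GenerationConfig.balanced
    config.inspect
    # => "#<Candle::GenerationConfig temp=0.7 max=512>"
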
data/lib/candle/ner.rb CHANGED
@@ -1,5 +1,8 @@
 # frozen_string_literal: true
 
+# Pattern validation available but not forced
+# require_relative 'pattern_validator' # Uncomment if needed
+
 module Candle
   # Named Entity Recognition (NER) for token classification
   #
@@ -30,10 +33,10 @@ module Candle
       # Load a pre-trained NER model from HuggingFace
       #
      # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
-      # @param device [Device] Device to run on (defaults to best available)
+      # @param device [Device] Device to run on (defaults to best available)
       # @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
       # @return [NER] NER instance
-      def from_pretrained(model_id, device: nil, tokenizer: nil)
+      def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
        new(model_id, device, tokenizer)
      end
 
@@ -112,7 +115,7 @@ module Candle
     # @return [Array<Hash>] Filtered entities of the specified type
     def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
       entities = extract_entities(text, confidence_threshold: confidence_threshold)
-      entities.select { |e| e["label"] == entity_type.upcase }
+      entities.select { |e| e[:label] == entity_type.upcase }
     end
 
     # Analyze text and return both entities and token predictions
@@ -137,12 +140,12 @@ module Candle
       return text if entities.empty?
 
       # Sort by start position (reverse for easier insertion)
-      entities.sort_by! { |e| -e["start"] }
+      entities.sort_by! { |e| -e[:start] }
 
       result = text.dup
       entities.each do |entity|
-        label = "[#{entity['label']}:#{entity['confidence'].round(2)}]"
-        result.insert(entity["end"], label)
+        label = "[#{entity[:label]}:#{entity[:confidence].round(2)}]"
+        result.insert(entity[:end], label)
       end
 
       result
@@ -152,7 +155,19 @@ module Candle
     #
     # @return [String] Model description
     def inspect
-      "#<Candle::NER #{model_info}>"
+      opts = options rescue {}
+
+      parts = ["#<Candle::NER"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "labels=#{opts["num_labels"]}" if opts["num_labels"]
+
+      if opts["entity_types"] && !opts["entity_types"].empty?
+        types = opts["entity_types"].sort.join(",")
+        parts << "types=#{types}"
+      end
+
+      parts.join(" ") + ">"
     end
 
     alias to_s inspect
@@ -177,6 +192,14 @@ module Candle
     def recognize(text, tokenizer = nil)
       entities = []
 
+      # Limit text length to prevent ReDoS on very long strings
+      # This is especially important for Ruby < 3.2
+      max_length = 1_000_000 # 1MB of text
+      if text.length > max_length
+        warn "PatternEntityRecognizer: Text truncated from #{text.length} to #{max_length} chars for safety"
+        text = text[0...max_length]
+      end
+
       @patterns.each do |pattern|
         regex = pattern.is_a?(Regexp) ? pattern : Regexp.new(pattern)
 
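
The one-million-character cap guards the pattern recognizer against catastrophic regex backtracking (ReDoS), which matters most on Ruby < 3.2 where `Regexp.timeout` is unavailable. Usage is unchanged apart from the warning on oversized input (constructor arguments assumed for illustration):

    recognizer = Candle::PatternEntityRecognizer.new("EMAIL", [/\b[\w.+-]+@[\w-]+\.\w+\b/])
    entities = recognizer.recognize(very_long_text) # input over 1_000_000 chars is truncated, with a warning
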
@@ -186,12 +209,12 @@ module Candle
         match_end = $~.offset(0)[1]
 
         entities << {
-          "text" => match_text,
-          "label" => @entity_type,
-          "start" => match_start,
-          "end" => match_end,
-          "confidence" => 1.0,
-          "source" => "pattern"
+          text: match_text,
+          label: @entity_type,
+          start: match_start,
+          end: match_end,
+          confidence: 1.0,
+          source: "pattern"
         }
       end
     end
@@ -242,12 +265,12 @@ module Candle
 
       if word_boundary?(prev_char) && word_boundary?(next_char)
         entities << {
-          "text" => text[idx, pattern.length],
-          "label" => @entity_type,
-          "start" => idx,
-          "end" => idx + pattern.length,
-          "confidence" => 1.0,
-          "source" => "gazetteer"
+          text: text[idx, pattern.length],
+          label: @entity_type,
+          start: idx,
+          end: idx + pattern.length,
+          confidence: 1.0,
+          source: "gazetteer"
         }
       end
 
@@ -327,19 +350,19 @@ module Candle
 
     def merge_entities(entities)
       # Sort by start position and confidence (descending)
-      sorted = entities.sort_by { |e| [e["start"], -e["confidence"]] }
+      sorted = entities.sort_by { |e| [e[:start], -e[:confidence]] }
 
       merged = []
       sorted.each do |entity|
        # Check if entity overlaps with any already merged
        overlaps = merged.any? do |existing|
-          entity["start"] < existing["end"] && entity["end"] > existing["start"]
+          entity[:start] < existing[:end] && entity[:end] > existing[:start]
        end
 
        merged << entity unless overlaps
       end
 
-      merged.sort_by { |e| e["start"] }
+      merged.sort_by { |e| e[:start] }
     end
   end
 end
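
Throughout the NER helpers, entity hashes move from string keys to symbol keys, the same breaking change as in the tokenizer hunk. Downstream code needs the corresponding update:

    entities = ner.extract_entities("Apple hired John Smith.")
    entities.each do |e|
      # 1.1.2: e["label"], e["start"], e["end"], e["confidence"]
      puts "#{e[:label]}: #{e[:text]} (#{e[:confidence].round(2)})" # 1.2.1 symbol keys
    end
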
@@ -3,10 +3,20 @@ module Candle
     # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
     DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
+    # Load a pre-trained reranker model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @return [Reranker] A new Reranker instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
+      _create(model_id, device)
+    end
+
     # Constructor for creating a new Reranker with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
-    def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.cpu)
+    def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best)
+      $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
       _create(model_path, device)
     end
 
@@ -20,5 +30,14 @@ module Candle
         { doc_id: doc_id, score: score, text: doc }
       }
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+      parts = ["#<Candle::Reranker"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts.join(" ") + ">"
+    end
   end
 end
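
The reranker follows the same migration pattern: `from_pretrained` as the new entry point, a deprecated `new`, and an options-driven `inspect`. The scoring method's name is outside this hunk, so it is assumed here:

    reranker = Candle::Reranker.from_pretrained # cross-encoder/ms-marco-MiniLM-L-12-v2

    results = reranker.rerank("what is a tensor?", ["Tensors are arrays.", "Candles burn."])
    # => [{ doc_id: 0, score: ..., text: "Tensors are arrays." }, ...] (method name assumed)
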
data/lib/candle/tensor.rb CHANGED
@@ -51,6 +51,21 @@ module Candle
       to_f.to_i
     end
 
+    # Improved inspect method showing shape, dtype, and device
+    def inspect
+      shape_str = shape.join("x")
+
+      parts = ["#<Candle::Tensor"]
+      parts << "shape=#{shape_str}"
+      parts << "dtype=#{dtype}"
+      parts << "device=#{device}"
+
+      # Add element count for clarity
+      parts << "elements=#{elem_count}"
+
+      parts.join(" ") + ">"
+    end
+
 
     # Override class methods to support keyword arguments for device
     class << self
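
Finally, `Tensor#inspect` on the Ruby side reports shape, dtype, device, and element count instead of the default object dump:

    t = Candle::Tensor.rand([2, 3])
    t.inspect # => "#<Candle::Tensor shape=2x3 dtype=f32 device=metal elements=6>" (illustrative)
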
@@ -1,5 +1,5 @@
 # :nocov:
 module Candle
-  VERSION = "1.1.2"
+  VERSION = "1.2.1"
 end
 # :nocov: