red-candle 1.1.2 → 1.2.0

@@ -6,7 +6,7 @@ use crate::ruby::{
     utils::{actual_dim, actual_index},
 };
 use crate::ruby::{DType, Device, Result};
-use ::candle_core::{DType as CoreDType, Tensor as CoreTensor};
+use ::candle_core::{DType as CoreDType, Tensor as CoreTensor, Device as CoreDevice, DeviceLocation};
 
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::Tensor", free_immediately, size)]
@@ -21,30 +21,108 @@ impl std::ops::Deref for Tensor {
     }
 }
 
+// Helper functions for tensor operations
+impl Tensor {
+    /// Check if device is Metal
+    fn is_metal_device(device: &CoreDevice) -> bool {
+        matches!(device.location(), DeviceLocation::Metal { .. })
+    }
+
+    /// Convert tensor to target dtype, handling Metal limitations
+    fn safe_to_dtype(&self, target_dtype: CoreDType) -> Result<CoreTensor> {
+        if Self::is_metal_device(self.0.device()) && self.0.dtype() != target_dtype {
+            // Move to CPU first to avoid Metal conversion limitations
+            self.0
+                .to_device(&CoreDevice::Cpu)
+                .map_err(wrap_candle_err)?
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        } else {
+            // Direct conversion for CPU or when dtype matches
+            self.0
+                .to_dtype(target_dtype)
+                .map_err(wrap_candle_err)
+        }
+    }
+}
+
 impl Tensor {
     pub fn new(array: magnus::RArray, dtype: Option<magnus::Symbol>, device: Option<Device>) -> Result<Self> {
         let dtype = dtype
             .map(|dtype| DType::from_rbobject(dtype))
             .unwrap_or(Ok(DType(CoreDType::F32)))?;
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
-        // FIXME: Do not use `to_f64` here.
-        let array = array
-            .into_iter()
-            .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self(
-            CoreTensor::new(array.as_slice(), &device)
-                .map_err(wrap_candle_err)?
-                .to_dtype(dtype.0)
-                .map_err(wrap_candle_err)?,
-        ))
+        let device = device.unwrap_or(Device::best()).as_device()?;
+
+        // Create tensor based on target dtype to avoid conversion issues on Metal
+        let tensor = match dtype.0 {
+            CoreDType::F32 => {
+                // Convert to f32 directly to avoid F64->F32 conversion on Metal
+                let array: Vec<f32> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64() as f32))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::F64 => {
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            CoreDType::I64 => {
+                // Convert to i64 directly to avoid conversion issues on Metal
+                let array: Vec<i64> = array
+                    .into_iter()
+                    .map(|v| {
+                        // Try integer first, then float
+                        if let Ok(i) = <i64>::try_convert(v) {
+                            Ok(i)
+                        } else if let Ok(f) = magnus::Float::try_convert(v) {
+                            Ok(f.to_f64() as i64)
+                        } else {
+                            Err(magnus::Error::new(
+                                magnus::exception::type_error(),
+                                "Cannot convert to i64"
+                            ))
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let len = array.len();
+                CoreTensor::from_vec(array, len, &device).map_err(wrap_candle_err)?
+            }
+            _ => {
+                // For other dtypes, create on CPU first if on Metal, then convert
+                let cpu_device = CoreDevice::Cpu;
+                let use_cpu = Self::is_metal_device(&device);
+                let target_device = if use_cpu { &cpu_device } else { &device };
+
+                let array: Vec<f64> = array
+                    .into_iter()
+                    .map(|v| magnus::Float::try_convert(v).map(|v| v.to_f64()))
+                    .collect::<Result<Vec<_>>>()?;
+                let tensor = CoreTensor::new(array.as_slice(), target_device)
+                    .map_err(wrap_candle_err)?
+                    .to_dtype(dtype.0)
+                    .map_err(wrap_candle_err)?;
+
+                // Move to target device if needed
+                if use_cpu {
+                    tensor.to_device(&device).map_err(wrap_candle_err)?
+                } else {
+                    tensor
+                }
+            }
+        };
+
+        Ok(Self(tensor))
     }
 
     pub fn values(&self) -> Result<Vec<f64>> {
-        let values = self
-            .0
-            .to_dtype(CoreDType::F64)
-            .map_err(wrap_candle_err)?
+        let tensor = self.safe_to_dtype(CoreDType::F64)?;
+        let values = tensor
             .flatten_all()
             .map_err(wrap_candle_err)?
             .to_vec1()
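
Net effect of this hunk for gem users: `Tensor.new` now lands on the best available device and builds the buffer in the requested dtype up front, so Metal never needs the unsupported F64→F32 conversion. A minimal sketch from Ruby, assuming the `device:` keyword form that lib/candle/tensor.rb (later in this diff) layers over the native method:

    require "candle"

    # Placement now follows Device.best (Metal > CUDA > CPU).
    t = Candle::Tensor.new([1.0, 2.0, 3.0], :f32)

    # Integer input takes the direct i64 path, with no dtype round-trip.
    ids = Candle::Tensor.new([1, 2, 3], :i64)

    # Pin placement explicitly when you need CPU semantics.
    cpu = Candle::Tensor.new([1.0, 2.0], :f32, device: Candle::Device.cpu)
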
@@ -92,11 +170,8 @@ impl Tensor {
             }
             _ => {
                 // For other dtypes, convert to F64 first
-                let val: f64 = self.0
-                    .to_dtype(CoreDType::F64)
-                    .map_err(wrap_candle_err)?
-                    .to_vec0()
-                    .map_err(wrap_candle_err)?;
+                let tensor = self.safe_to_dtype(CoreDType::F64)?;
+                let val: f64 = tensor.to_vec0().map_err(wrap_candle_err)?;
                 Ok(val)
             }
         }
@@ -541,7 +616,7 @@ impl Tensor {
     /// Creates a new tensor with random values.
     /// &RETURNS&: Tensor
     pub fn rand(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::rand(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -550,7 +625,7 @@ impl Tensor {
     /// Creates a new tensor with random values from a normal distribution.
     /// &RETURNS&: Tensor
     pub fn randn(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::randn(0f32, 1f32, shape, &device).map_err(wrap_candle_err)?,
         ))
@@ -559,7 +634,7 @@ impl Tensor {
     /// Creates a new tensor filled with ones.
     /// &RETURNS&: Tensor
     pub fn ones(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::ones(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
@@ -567,7 +642,7 @@ impl Tensor {
     /// Creates a new tensor filled with zeros.
     /// &RETURNS&: Tensor
     pub fn zeros(shape: Vec<usize>, device: Option<Device>) -> Result<Self> {
-        let device = device.unwrap_or(Device::Cpu).as_device()?;
+        let device = device.unwrap_or(Device::best()).as_device()?;
         Ok(Self(
             CoreTensor::zeros(shape, CoreDType::F32, &device).map_err(wrap_candle_err)?,
         ))
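
The same default-device change covers all four factory methods. A quick sketch of what this means for callers, assuming the Ruby wrappers forward the optional `device:` keyword as lib/candle/tensor.rb suggests:

    # 1.1.x: allocated on CPU unless a device was passed.
    # 1.2.0: allocated on Device.best by default.
    zeros = Candle::Tensor.zeros([2, 3])
    noise = Candle::Tensor.randn([4, 4])

    # The old behavior is one keyword away:
    zeros_cpu = Candle::Tensor.zeros([2, 3], device: Candle::Device.cpu)
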
@@ -105,8 +105,8 @@ impl Tokenizer {
         }
 
         let hash = RHash::new();
-        hash.aset("ids", RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
-        hash.aset("tokens", RArray::from_vec(tokens))?;
+        hash.aset(magnus::Symbol::new("ids"), RArray::from_vec(token_ids.into_iter().map(|id| id as i64).collect()))?;
+        hash.aset(magnus::Symbol::new("tokens"), RArray::from_vec(tokens))?;
 
         Ok(hash)
     }
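
This is a breaking change for callers that indexed the returned hash with strings. A sketch of the new access pattern, assuming `result` is the hash built by this Rust path (the Ruby method name is not shown in this hunk):

    result[:ids]      # 1.2.0: symbol keys
    result[:tokens]
    result["ids"]     # 1.1.x style string key; now returns nil
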
@@ -236,9 +236,65 @@ impl Tokenizer {
         Ok(hash)
     }
 
+    /// Get tokenizer options as a hash
+    pub fn options(&self) -> Result<RHash> {
+        let hash = RHash::new();
+
+        // Get vocab size
+        hash.aset("vocab_size", self.vocab_size(Some(true)))?;
+        hash.aset("vocab_size_base", self.vocab_size(Some(false)))?;
+
+        // Get special tokens info
+        let special_tokens = self.get_special_tokens()?;
+        hash.aset("special_tokens", special_tokens)?;
+
+        // Get padding/truncation info if available
+        let inner_tokenizer = self.0.inner();
+
+        // Check if padding is enabled
+        if let Some(_padding) = inner_tokenizer.get_padding() {
+            let padding_info = RHash::new();
+            padding_info.aset("enabled", true)?;
+            // Note: We can't easily extract all padding params from the tokenizers library
+            // but we can indicate it's enabled
+            hash.aset("padding", padding_info)?;
+        }
+
+        // Check if truncation is enabled
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            let truncation_info = RHash::new();
+            truncation_info.aset("enabled", true)?;
+            truncation_info.aset("max_length", truncation.max_length)?;
+            hash.aset("truncation", truncation_info)?;
+        }
+
+        Ok(hash)
+    }
+
     /// String representation
     pub fn inspect(&self) -> String {
-        format!("#<Candle::Tokenizer vocab_size={}>", self.vocab_size(Some(true)))
+        let vocab_size = self.vocab_size(Some(true));
+        let special_tokens = self.get_special_tokens()
+            .ok()
+            .map(|h| h.len())
+            .unwrap_or(0);
+
+        let mut parts = vec![format!("#<Candle::Tokenizer vocab_size={}", vocab_size)];
+
+        if special_tokens > 0 {
+            parts.push(format!("special_tokens={}", special_tokens));
+        }
+
+        // Check for padding/truncation
+        let inner_tokenizer = self.0.inner();
+        if inner_tokenizer.get_padding().is_some() {
+            parts.push("padding=enabled".to_string());
+        }
+        if let Some(truncation) = inner_tokenizer.get_truncation() {
+            parts.push(format!("truncation={}", truncation.max_length));
+        }
+
+        parts.join(" ") + ">"
    }
 }
 
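
`Tokenizer#options` exposes as data what `inspect` now summarizes. A hedged usage sketch, with illustrative values; `tok` stands for any loaded `Candle::Tokenizer`, and only `options` and `inspect` themselves are defined in this diff:

    tok.options
    # => { "vocab_size" => 30522, "vocab_size_base" => 30522,
    #      "special_tokens" => {...},
    #      "truncation" => { "enabled" => true, "max_length" => 512 } }

    tok.inspect
    # => "#<Candle::Tokenizer vocab_size=30522 special_tokens=5 truncation=512>"
    # (padding/truncation segments appear only when configured)
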
@@ -262,6 +318,7 @@ pub fn init(rb_candle: RModule) -> Result<()> {
     tokenizer_class.define_method("with_padding", method!(Tokenizer::with_padding, 1))?;
     tokenizer_class.define_method("with_truncation", method!(Tokenizer::with_truncation, 1))?;
     tokenizer_class.define_method("get_special_tokens", method!(Tokenizer::get_special_tokens, 0))?;
+    tokenizer_class.define_method("options", method!(Tokenizer::options, 0))?;
     tokenizer_class.define_method("inspect", method!(Tokenizer::inspect, 0))?;
     tokenizer_class.define_method("to_s", method!(Tokenizer::inspect, 0))?;
 
@@ -1,22 +1,10 @@
 module Candle
   module DeviceUtils
+    # @deprecated Use {Candle::Device.best} instead
     # Get the best available device (Metal > CUDA > CPU)
     def self.best_device
-      # Try devices in order of preference
-      begin
-        # Try Metal first (for Mac users)
-        Device.metal
-      rescue
-        # :nocov:
-        begin
-          # Try CUDA next (for NVIDIA GPU users)
-          Device.cuda
-        rescue
-          # Fall back to CPU
-          Device.cpu
-        end
-        # :nocov:
-      end
+      warn "[DEPRECATION] `DeviceUtils.best_device` is deprecated. Please use `Device.best` instead."
+      Device.best
     end
   end
 end
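
The fallback chain now lives in `Device.best`, so migration is mechanical:

    # Deprecated (still works, now warns to stderr):
    device = Candle::DeviceUtils.best_device

    # Preferred from 1.2.0:
    device = Candle::Device.best
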
@@ -9,7 +9,36 @@ module Candle
     # Default embedding model type
     DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
 
+    # Load a pre-trained embedding model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @param tokenizer [String, nil] The tokenizer to use (defaults to the model's own tokenizer)
+    # @param model_type [String, nil] The type of embedding model (auto-detected if nil)
+    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
+    # @return [EmbeddingModel] A new EmbeddingModel instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
+      # Auto-detect model type based on model_id if not provided
+      if model_type.nil?
+        model_type = case model_id.downcase
+                     when /jina/
+                       "jina_bert"
+                     when /distilbert/
+                       "distilbert"
+                     when /minilm/
+                       "minilm"
+                     else
+                       "standard_bert"
+                     end
+      end
+
+      # Use model_id as tokenizer if not specified (usually what you want)
+      tokenizer_id = tokenizer || model_id
+
+      _create(model_id, tokenizer_id, device, model_type, embedding_size)
+    end
+
     # Constructor for creating a new EmbeddingModel with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
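
Typical calls against the new constructor; the model ids are illustrative, and the detection rules are exactly the case expression above:

    # Auto-detects "jina_bert" from the id, runs on Device.best:
    model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en")

    # Force the type and placement explicitly:
    model = Candle::EmbeddingModel.from_pretrained(
      "sentence-transformers/all-MiniLM-L6-v2",  # matches /minilm/ anyway
      model_type: "minilm",
      device: Candle::Device.cpu
    )
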
@@ -17,9 +46,10 @@ module Candle
     # @param embedding_size [Integer, nil] Override for the embedding size (optional)
     def self.new(model_path: DEFAULT_MODEL_PATH,
                  tokenizer_path: DEFAULT_TOKENIZER_PATH,
-                 device: Candle::Device.cpu,
+                 device: Candle::Device.best,
                  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
                  embedding_size: nil)
+      $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
       _create(model_path, tokenizer_path, device, model_type, embedding_size)
     end
     # Returns the embedding for a string using the specified pooling method.
@@ -28,5 +58,18 @@ module Candle
     def embedding(str, pooling_method: "pooled_normalized")
       _embedding(str, pooling_method)
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::EmbeddingModel"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "type=#{opts["model_type"]}" if opts["model_type"]
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
+
+      parts.join(" ") + ">"
+    end
   end
 end
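
Roughly what the new output looks like; the exact model, type, and device strings depend on what was loaded and on the machine:

    model.inspect
    # => "#<Candle::EmbeddingModel model=jinaai/jina-embeddings-v2-base-en type=jina_bert device=metal>"
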
data/lib/candle/llm.rb CHANGED
@@ -189,6 +189,45 @@ module Candle
       prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      # Extract key information
+      model_type = opts["model_type"] || "Unknown"
+      device = opts["device"] || self.device.to_s rescue "unknown"
+
+      # Build the inspect string
+      parts = ["#<Candle::LLM"]
+
+      # Add base model or model_id
+      if opts["base_model"]
+        parts << "model=#{opts["base_model"]}"
+      elsif opts["model_id"]
+        parts << "model=#{opts["model_id"]}"
+      elsif respond_to?(:model_id)
+        parts << "model=#{model_id}"
+      end
+
+      # Add GGUF file if present
+      if opts["gguf_file"]
+        parts << "gguf=#{opts["gguf_file"]}"
+      end
+
+      # Add device
+      parts << "device=#{device}"
+
+      # Add model type
+      parts << "type=#{model_type}"
+
+      # Add architecture for GGUF models
+      if opts["architecture"]
+        parts << "arch=#{opts["architecture"]}"
+      end
+
+      parts.join(" ") + ">"
+    end
 
     def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
       begin
@@ -206,7 +245,7 @@ module Candle
       end
     end
 
-    def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
+    def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
       model_str = if gguf_file
         "#{model_id}@#{gguf_file}"
       else
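
Only the default device changed here; the rest of the signature is identical. A sketch with a placeholder model id:

    # 1.2.0: lands on Metal/CUDA when available.
    llm = Candle::LLM.from_pretrained("org/model-id")

    # Reproduce the 1.1.x behavior:
    llm = Candle::LLM.from_pretrained("org/model-id", device: Candle::Device.cpu)
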
@@ -393,5 +432,28 @@ module Candle
       }
       new(defaults.merge(opts))
     end
+
+    # Inspect method for debugging and exploration
+    def inspect
+      opts = options rescue {}
+
+      parts = ["#<Candle::GenerationConfig"]
+
+      # Add key configuration parameters
+      parts << "temp=#{opts["temperature"]}" if opts["temperature"]
+      parts << "max=#{opts["max_length"]}" if opts["max_length"]
+      parts << "top_p=#{opts["top_p"]}" if opts["top_p"]
+      parts << "top_k=#{opts["top_k"]}" if opts["top_k"]
+      parts << "seed=#{opts["seed"]}" if opts["seed"]
+
+      # Add flags
+      flags = []
+      flags << "debug" if opts["debug_tokens"]
+      flags << "constraint" if opts["has_constraint"]
+      flags << "stop_on_match" if opts["stop_on_match"]
+      parts << "flags=[#{flags.join(",")}]" if flags.any?
+
+      parts.join(" ") + ">"
+    end
   end
 end
data/lib/candle/ner.rb CHANGED
@@ -30,10 +30,10 @@ module Candle
     # Load a pre-trained NER model from HuggingFace
     #
     # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
-    # @param device [Device, nil] Device to run on (defaults to best available)
+    # @param device [Device] Device to run on (defaults to best available)
     # @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
     # @return [NER] NER instance
-    def from_pretrained(model_id, device: nil, tokenizer: nil)
+    def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
       new(model_id, device, tokenizer)
     end
 
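
The device keyword now defaults to a real device instead of nil. A sketch of the call after this change, using the model id from the doc comment above:

    # Defaults to the best available device in 1.2.0:
    ner = Candle::NER.from_pretrained("dslim/bert-base-NER")

    # Explicit placement still works:
    ner = Candle::NER.from_pretrained("dslim/bert-base-NER", device: Candle::Device.cpu)
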
@@ -112,7 +112,7 @@ module Candle
     # @return [Array<Hash>] Filtered entities of the specified type
     def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
       entities = extract_entities(text, confidence_threshold: confidence_threshold)
-      entities.select { |e| e["label"] == entity_type.upcase }
+      entities.select { |e| e[:label] == entity_type.upcase }
     end
 
     # Analyze text and return both entities and token predictions
@@ -137,12 +137,12 @@ module Candle
       return text if entities.empty?
 
       # Sort by start position (reverse for easier insertion)
-      entities.sort_by! { |e| -e["start"] }
+      entities.sort_by! { |e| -e[:start] }
 
       result = text.dup
       entities.each do |entity|
-        label = "[#{entity['label']}:#{entity['confidence'].round(2)}]"
-        result.insert(entity["end"], label)
+        label = "[#{entity[:label]}:#{entity[:confidence].round(2)}]"
+        result.insert(entity[:end], label)
       end
 
       result
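
Entity hashes are keyed by symbols throughout this file from 1.2.0, which is a breaking change for 1.1.x callers. A sketch of the new access pattern, assuming the native `extract_entities` output follows the same symbol-key convention these consumers rely on:

    entities = ner.extract_entities("Apple was founded by Steve Jobs")
    entities.each do |e|
      puts "#{e[:text]} (#{e[:label]}) #{e[:start]}..#{e[:end]} conf=#{e[:confidence]}"
    end
    # 1.1.x used string keys: e["text"], e["label"], ...
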
@@ -152,7 +152,19 @@ module Candle
     #
     # @return [String] Model description
     def inspect
-      "#<Candle::NER #{model_info}>"
+      opts = options rescue {}
+
+      parts = ["#<Candle::NER"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts << "labels=#{opts["num_labels"]}" if opts["num_labels"]
+
+      if opts["entity_types"] && !opts["entity_types"].empty?
+        types = opts["entity_types"].sort.join(",")
+        parts << "types=#{types}"
+      end
+
+      parts.join(" ") + ">"
     end
 
     alias to_s inspect
@@ -186,12 +198,12 @@ module Candle
         match_end = $~.offset(0)[1]
 
         entities << {
-          "text" => match_text,
-          "label" => @entity_type,
-          "start" => match_start,
-          "end" => match_end,
-          "confidence" => 1.0,
-          "source" => "pattern"
+          text: match_text,
+          label: @entity_type,
+          start: match_start,
+          end: match_end,
+          confidence: 1.0,
+          source: "pattern"
         }
       end
     end
@@ -242,12 +254,12 @@ module Candle
 
       if word_boundary?(prev_char) && word_boundary?(next_char)
         entities << {
-          "text" => text[idx, pattern.length],
-          "label" => @entity_type,
-          "start" => idx,
-          "end" => idx + pattern.length,
-          "confidence" => 1.0,
-          "source" => "gazetteer"
+          text: text[idx, pattern.length],
+          label: @entity_type,
+          start: idx,
+          end: idx + pattern.length,
+          confidence: 1.0,
+          source: "gazetteer"
         }
       end
 
@@ -327,19 +339,19 @@ module Candle
 
     def merge_entities(entities)
       # Sort by start position and confidence (descending)
-      sorted = entities.sort_by { |e| [e["start"], -e["confidence"]] }
+      sorted = entities.sort_by { |e| [e[:start], -e[:confidence]] }
 
       merged = []
       sorted.each do |entity|
         # Check if entity overlaps with any already merged
         overlaps = merged.any? do |existing|
-          entity["start"] < existing["end"] && entity["end"] > existing["start"]
+          entity[:start] < existing[:end] && entity[:end] > existing[:start]
         end
 
         merged << entity unless overlaps
       end
 
-      merged.sort_by { |e| e["start"] }
+      merged.sort_by { |e| e[:start] }
     end
   end
 end
@@ -3,10 +3,20 @@ module Candle
     # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
     DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
 
+    # Load a pre-trained reranker model from HuggingFace
+    # @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
+    # @param device [Candle::Device] The device to use for computation (defaults to best available)
+    # @return [Reranker] A new Reranker instance
+    def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
+      _create(model_id, device)
+    end
+
     # Constructor for creating a new Reranker with optional parameters
+    # @deprecated Use {.from_pretrained} instead
     # @param model_path [String, nil] The path to the model on Hugging Face
     # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
-    def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.cpu)
+    def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best)
+      $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
       _create(model_path, device)
     end
 
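
A usage sketch: the result shape comes from the hash built in the next hunk, the scores are illustrative, and the `rerank(query, documents)` method name is assumed from the gem's public API rather than shown in this diff:

    reranker = Candle::Reranker.from_pretrained  # cross-encoder/ms-marco-MiniLM-L-12-v2
    reranker.rerank("query", ["doc one", "doc two"])
    # => [{ doc_id: 0, score: 0.93, text: "doc one" }, ...]
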
@@ -20,5 +30,14 @@ module Candle
         { doc_id: doc_id, score: score, text: doc }
       }
     end
+
+    # Improved inspect method
+    def inspect
+      opts = options rescue {}
+      parts = ["#<Candle::Reranker"]
+      parts << "model=#{opts["model_id"] || "unknown"}"
+      parts << "device=#{opts["device"] || "unknown"}"
+      parts.join(" ") + ">"
+    end
   end
 end
data/lib/candle/tensor.rb CHANGED
@@ -51,6 +51,21 @@ module Candle
       to_f.to_i
     end
 
+    # Improved inspect method showing shape, dtype, and device
+    def inspect
+      shape_str = shape.join("x")
+
+      parts = ["#<Candle::Tensor"]
+      parts << "shape=#{shape_str}"
+      parts << "dtype=#{dtype}"
+      parts << "device=#{device}"
+
+      # Add element count for clarity
+      parts << "elements=#{elem_count}"
+
+      parts.join(" ") + ">"
+    end
+
 
     # Override class methods to support keyword arguments for device
     class << self
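
What the new output looks like for a 2x3 F32 tensor; the exact dtype and device strings depend on the platform and on how those objects stringify:

    Candle::Tensor.ones([2, 3]).inspect
    # => "#<Candle::Tensor shape=2x3 dtype=f32 device=metal elements=6>"
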
@@ -1,5 +1,5 @@
 # :nocov:
 module Candle
-  VERSION = "1.1.2"
+  VERSION = "1.2.0"
 end
 # :nocov: