red-candle 1.1.1 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/candle/ner.rb CHANGED
@@ -30,10 +30,10 @@ module Candle
30
30
  # Load a pre-trained NER model from HuggingFace
31
31
  #
32
32
  # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
33
- # @param device [Device, nil] Device to run on (defaults to best available)
33
+ # @param device [Device] Device to run on (defaults to best available)
34
34
  # @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
35
35
  # @return [NER] NER instance
36
- def from_pretrained(model_id, device: nil, tokenizer: nil)
36
+ def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
37
37
  new(model_id, device, tokenizer)
38
38
  end
39
39
 
@@ -112,7 +112,7 @@ module Candle
112
112
  # @return [Array<Hash>] Filtered entities of the specified type
113
113
  def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
114
114
  entities = extract_entities(text, confidence_threshold: confidence_threshold)
115
- entities.select { |e| e["label"] == entity_type.upcase }
115
+ entities.select { |e| e[:label] == entity_type.upcase }
116
116
  end
117
117
 
118
118
  # Analyze text and return both entities and token predictions
@@ -137,12 +137,12 @@ module Candle
137
137
  return text if entities.empty?
138
138
 
139
139
  # Sort by start position (reverse for easier insertion)
140
- entities.sort_by! { |e| -e["start"] }
140
+ entities.sort_by! { |e| -e[:start] }
141
141
 
142
142
  result = text.dup
143
143
  entities.each do |entity|
144
- label = "[#{entity['label']}:#{entity['confidence'].round(2)}]"
145
- result.insert(entity["end"], label)
144
+ label = "[#{entity[:label]}:#{entity[:confidence].round(2)}]"
145
+ result.insert(entity[:end], label)
146
146
  end
147
147
 
148
148
  result
@@ -152,7 +152,19 @@ module Candle
152
152
  #
153
153
  # @return [String] Model description
154
154
  def inspect
155
- "#<Candle::NER #{model_info}>"
155
+ opts = options rescue {}
156
+
157
+ parts = ["#<Candle::NER"]
158
+ parts << "model=#{opts["model_id"] || "unknown"}"
159
+ parts << "device=#{opts["device"] || "unknown"}"
160
+ parts << "labels=#{opts["num_labels"]}" if opts["num_labels"]
161
+
162
+ if opts["entity_types"] && !opts["entity_types"].empty?
163
+ types = opts["entity_types"].sort.join(",")
164
+ parts << "types=#{types}"
165
+ end
166
+
167
+ parts.join(" ") + ">"
156
168
  end
157
169
 
158
170
  alias to_s inspect
@@ -186,12 +198,12 @@ module Candle
186
198
  match_end = $~.offset(0)[1]
187
199
 
188
200
  entities << {
189
- "text" => match_text,
190
- "label" => @entity_type,
191
- "start" => match_start,
192
- "end" => match_end,
193
- "confidence" => 1.0,
194
- "source" => "pattern"
201
+ text: match_text,
202
+ label: @entity_type,
203
+ start: match_start,
204
+ end: match_end,
205
+ confidence: 1.0,
206
+ source: "pattern"
195
207
  }
196
208
  end
197
209
  end
@@ -242,12 +254,12 @@ module Candle
242
254
 
243
255
  if word_boundary?(prev_char) && word_boundary?(next_char)
244
256
  entities << {
245
- "text" => text[idx, pattern.length],
246
- "label" => @entity_type,
247
- "start" => idx,
248
- "end" => idx + pattern.length,
249
- "confidence" => 1.0,
250
- "source" => "gazetteer"
257
+ text: text[idx, pattern.length],
258
+ label: @entity_type,
259
+ start: idx,
260
+ end: idx + pattern.length,
261
+ confidence: 1.0,
262
+ source: "gazetteer"
251
263
  }
252
264
  end
253
265
 
@@ -327,19 +339,19 @@ module Candle
327
339
 
328
340
  def merge_entities(entities)
329
341
  # Sort by start position and confidence (descending)
330
- sorted = entities.sort_by { |e| [e["start"], -e["confidence"]] }
342
+ sorted = entities.sort_by { |e| [e[:start], -e[:confidence]] }
331
343
 
332
344
  merged = []
333
345
  sorted.each do |entity|
334
346
  # Check if entity overlaps with any already merged
335
347
  overlaps = merged.any? do |existing|
336
- entity["start"] < existing["end"] && entity["end"] > existing["start"]
348
+ entity[:start] < existing[:end] && entity[:end] > existing[:start]
337
349
  end
338
350
 
339
351
  merged << entity unless overlaps
340
352
  end
341
353
 
342
- merged.sort_by { |e| e["start"] }
354
+ merged.sort_by { |e| e[:start] }
343
355
  end
344
356
  end
345
357
  end
@@ -3,10 +3,20 @@ module Candle
3
3
  # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
4
4
  DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
5
5
 
6
+ # Load a pre-trained reranker model from HuggingFace
7
+ # @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
8
+ # @param device [Candle::Device] The device to use for computation (defaults to best available)
9
+ # @return [Reranker] A new Reranker instance
10
+ def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best)
11
+ _create(model_id, device)
12
+ end
13
+
6
14
  # Constructor for creating a new Reranker with optional parameters
15
+ # @deprecated Use {.from_pretrained} instead
7
16
  # @param model_path [String, nil] The path to the model on Hugging Face
8
17
  # @param device [Candle::Device] The device to use for computation (defaults to best available)
9
- def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.cpu)
18
+ def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best)
19
+ $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
10
20
  _create(model_path, device)
11
21
  end
12
22
 
@@ -20,5 +30,14 @@ module Candle
20
30
  { doc_id: doc_id, score: score, text: doc }
21
31
  }
22
32
  end
33
+
34
+ # Compact description of the reranker showing its model ID and device ("unknown" when unavailable)
35
+ def inspect
36
+ opts = options rescue {}
37
+ parts = ["#<Candle::Reranker"]
38
+ parts << "model=#{opts["model_id"] || "unknown"}"
39
+ parts << "device=#{opts["device"] || "unknown"}"
40
+ parts.join(" ") + ">"
41
+ end
23
42
  end
24
43
  end
data/lib/candle/tensor.rb CHANGED
@@ -51,6 +51,21 @@ module Candle
51
51
  to_f.to_i
52
52
  end
53
53
 
54
+ # Improved inspect method showing shape, dtype, and device
55
+ def inspect
56
+ shape_str = shape.join("x")
57
+
58
+ parts = ["#<Candle::Tensor"]
59
+ parts << "shape=#{shape_str}"
60
+ parts << "dtype=#{dtype}"
61
+ parts << "device=#{device}"
62
+
63
+ # Add element count for clarity
64
+ parts << "elements=#{elem_count}"
65
+
66
+ parts.join(" ") + ">"
67
+ end
68
+
54
69
 
55
70
  # Override class methods to support keyword arguments for device
56
71
  class << self
@@ -1,5 +1,5 @@
1
1
  # :nocov:
2
2
  module Candle
3
- VERSION = "1.1.1"
3
+ VERSION = "1.2.0"
4
4
  end
5
5
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.1
4
+ version: 1.2.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2025-07-28 00:00:00.000000000 Z
12
+ date: 2025-08-10 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rb_sys
@@ -137,6 +137,20 @@ dependencies:
137
137
  - - ">="
138
138
  - !ruby/object:Gem::Version
139
139
  version: '0'
140
+ - !ruby/object:Gem::Dependency
141
+ name: rspec
142
+ requirement: !ruby/object:Gem::Requirement
143
+ requirements:
144
+ - - "~>"
145
+ - !ruby/object:Gem::Version
146
+ version: '3.13'
147
+ type: :development
148
+ prerelease: false
149
+ version_requirements: !ruby/object:Gem::Requirement
150
+ requirements:
151
+ - - "~>"
152
+ - !ruby/object:Gem::Version
153
+ version: '3.13'
140
154
  description: huggingface/candle for Ruby
141
155
  email:
142
156
  - chris@petersen.io
@@ -169,14 +183,14 @@ files:
169
183
  - ext/candle/src/llm/quantized_gguf.rs
170
184
  - ext/candle/src/llm/qwen.rs
171
185
  - ext/candle/src/llm/text_generation.rs
172
- - ext/candle/src/ner.rs
173
- - ext/candle/src/reranker.rs
174
186
  - ext/candle/src/ruby/device.rs
175
187
  - ext/candle/src/ruby/dtype.rs
176
188
  - ext/candle/src/ruby/embedding_model.rs
177
189
  - ext/candle/src/ruby/errors.rs
178
190
  - ext/candle/src/ruby/llm.rs
179
191
  - ext/candle/src/ruby/mod.rs
192
+ - ext/candle/src/ruby/ner.rs
193
+ - ext/candle/src/ruby/reranker.rs
180
194
  - ext/candle/src/ruby/result.rs
181
195
  - ext/candle/src/ruby/structured.rs
182
196
  - ext/candle/src/ruby/tensor.rs
@@ -196,6 +210,8 @@ files:
196
210
  - ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs
197
211
  - ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs
198
212
  - ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs
213
+ - ext/candle/tests/device_tests.rs
214
+ - ext/candle/tests/tensor_tests.rs
199
215
  - lib/candle.rb
200
216
  - lib/candle/build_info.rb
201
217
  - lib/candle/device_utils.rb