red-candle 1.8.0.pre3-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,149 @@
1
+ require 'logger'
2
+
3
+ module Candle
4
+ # Logging functionality for the Red Candle gem
5
class << self
  # Returns the gem-wide logger, building the default one on first access.
  #
  # @return [Logger] The logger instance
  def logger
    @logger ||= create_default_logger
  end

  # Installs a caller-supplied logger.
  #
  # Assigning nil (or false) makes the next call to #logger build the
  # default logger again, because #logger memoizes with `||=`.
  #
  # @param custom_logger [Logger] A custom logger instance
  def logger=(custom_logger)
    @logger = custom_logger
  end

  # Builds a logger from a LoggerConfig and installs it as the gem-wide
  # logger. The block is optional; without it the LoggerConfig defaults apply.
  #
  # @yield [config] Configuration object
  # @return [Logger] The logger that was just installed
  def configure_logging
    config = LoggerConfig.new
    yield config if block_given?
    @logger = config.build_logger
  end

  private

  # Creates the default logger: writes to $stderr, prints only the message.
  #
  # @return [Logger] Configured logger instance
  def create_default_logger
    Logger.new($stderr).tap do |log|
      log.level = default_log_level
      log.formatter = cli_friendly_formatter
    end
  end

  # Determines the default log level from the legacy CANDLE_VERBOSE variable.
  #
  # The variable enables DEBUG only when it holds an affirmative value.
  # Unset, empty, "0", "false", "no" and "off" (any case, surrounding
  # whitespace ignored) keep the quiet WARN default, so `CANDLE_VERBOSE=0`
  # no longer switches verbose output on.
  #
  # @return [Integer] Logger level constant
  def default_log_level
    flag = ENV['CANDLE_VERBOSE'].to_s.strip.downcase
    return Logger::WARN if ['', '0', 'false', 'no', 'off'].include?(flag)

    Logger::DEBUG
  end

  # CLI-friendly formatter that outputs just the message and a newline.
  #
  # @return [Proc] Formatter proc
  def cli_friendly_formatter
    proc { |_severity, _datetime, _progname, msg| "#{msg}\n" }
  end
end
51
+
52
# Configuration helper for logger setup.
#
# Collects a level, an output target and a formatter, then turns them into
# a ::Logger via #build_logger.
class LoggerConfig
  # level:     Symbol/String name ("debug", "info", "warn"/"warning",
  #            "error", "fatal") or an Integer Logger severity constant.
  # output:    anything Logger.new accepts (IO object or file path).
  # formatter: :simple/:cli, :detailed, :json (Symbol or String), or any
  #            object responding to #call(severity, datetime, progname, msg).
  attr_accessor :level, :output, :formatter

  def initialize
    @level = :warn
    @output = $stderr
    @formatter = :simple
  end

  # Build a logger from the current configuration.
  #
  # @return [Logger] Configured logger
  def build_logger
    Logger.new(@output).tap do |log|
      log.level = normalize_level(@level)
      log.formatter = build_formatter(@formatter)
    end
  end

  # Set log level to debug (verbose output)
  def verbose!
    @level = :debug
  end

  # Set log level to info
  def info!
    @level = :info
  end

  # Set log level to warn (default)
  def quiet!
    @level = :warn
  end

  # Set log level to error (minimal output)
  def silent!
    @level = :error
  end

  # Log to stdout instead of stderr
  def log_to_stdout!
    @output = $stdout
  end

  # Log to a file
  # @param file_path [String] Path to log file
  def log_to_file!(file_path)
    @output = file_path
  end

  # Disable logging completely by pointing the output at the null device
  def disable!
    @output = File::NULL
  end

  private

  # Convert a level given as Symbol, String or Integer to a Logger constant.
  #
  # Integers are passed through untouched so that `config.level =
  # Logger::DEBUG` works as documented. Unrecognised names fall back to WARN.
  #
  # @param level [Symbol, String, Integer] Log level
  # @return [Integer] Logger level constant
  def normalize_level(level)
    return level if level.is_a?(Integer)

    case level.to_s.strip.downcase
    when 'debug' then Logger::DEBUG
    when 'info' then Logger::INFO
    when 'warn', 'warning' then Logger::WARN
    when 'error' then Logger::ERROR
    when 'fatal' then Logger::FATAL
    else Logger::WARN
    end
  end

  # Build the formatter for the configured type.
  #
  # Callables are used as-is; names may be Symbols or Strings. Anything
  # unrecognised (including :simple and :cli) yields the message-only format.
  #
  # @param formatter_type [Symbol, String, #call] Type of formatter
  # @return [#call] Formatter
  def build_formatter(formatter_type)
    return formatter_type if formatter_type.respond_to?(:call)

    case formatter_type.to_s.downcase
    when 'detailed'
      proc do |severity, datetime, _progname, msg|
        "[#{datetime.strftime('%Y-%m-%d %H:%M:%S')}] #{severity}: #{msg}\n"
      end
    when 'json'
      require 'json'
      # Time#iso8601 comes from the 'time' stdlib on Rubies before 3.4;
      # without this require the formatter raises NoMethodError there.
      require 'time'
      proc do |severity, datetime, progname, msg|
        JSON.generate({
          timestamp: datetime.iso8601,
          level: severity,
          message: msg,
          program: progname
        }) + "\n"
      end
    else
      proc { |_severity, _datetime, _progname, msg| "#{msg}\n" }
    end
  end
end
149
+ end
data/lib/candle/ner.rb ADDED
@@ -0,0 +1,368 @@
1
+ # frozen_string_literal: true
2
+
3
# Pattern validation available but not forced
# require_relative 'pattern_validator' # Uncomment if needed

# GazetteerEntityRecognizer stores its terms in a Set; Set is only
# autoloaded from Ruby 3.2 onwards, so load it explicitly.
require "set"
5
+
6
+ module Candle
7
+ # Named Entity Recognition (NER) for token classification
8
+ #
9
+ # This class provides methods to extract named entities from text using
10
+ # pre-trained BERT-based models. It supports standard NER labels like
11
+ # PER (person), ORG (organization), LOC (location), and can be extended
12
+ # with custom entity types.
13
+ #
14
+ # @example Load a pre-trained NER model
15
+ # ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
16
+ #
17
+ # @example Load a model with a specific tokenizer
18
+ # ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
19
+ #
20
+ # @example Extract entities from text
21
+ # entities = ner.extract_entities("Apple Inc. was founded by Steve Jobs in Cupertino.")
22
+ # # => [
23
+ # # { text: "Apple Inc.", label: "ORG", start: 0, end: 10, confidence: 0.99 },
24
+ # # { text: "Steve Jobs", label: "PER", start: 26, end: 36, confidence: 0.98 },
25
+ # # { text: "Cupertino", label: "LOC", start: 40, end: 49, confidence: 0.97 }
26
+ # # ]
27
+ #
28
+ # @example Get token-level predictions
29
+ # tokens = ner.predict_tokens("John works at Google")
30
+ # # Returns detailed token-by-token predictions with confidence scores
31
+ class NER
32
class << self
  # Build an NER instance for a HuggingFace model.
  #
  # Forwards the three values positionally to `new`, which is defined
  # outside this file.
  #
  # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
  # @param device [Device] Device to run on (defaults to Candle::Device.best)
  # @param tokenizer [String, nil] Tokenizer model ID to use (nil by default)
  # @return [NER] NER instance
  def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
    new(model_id, device, tokenizer)
  end

  # Catalogue of suggested pre-trained models, keyed by domain.
  #
  # Each entry has a :model ID and a :note; entries that need a separate
  # tokenizer also carry :tokenizer. A new Hash is built on every call.
  #
  # @return [Hash{Symbol => Hash}] domain => model description
  def suggested_models
    {
      general: { model: "Babelscape/wikineural-multilingual-ner", note: "Has tokenizer.json" },
      general_alt: {
        model: "dslim/bert-base-NER",
        tokenizer: "bert-base-cased",
        note: "Requires separate tokenizer"
      },
      multilingual: { model: "Davlan/bert-base-multilingual-cased-ner-hrl", note: "Check tokenizer availability" },
      biomedical: { model: "dmis-lab/biobert-base-cased-v1.2", note: "May require specific tokenizer" },
      clinical: { model: "emilyalsentzer/Bio_ClinicalBERT", note: "May require specific tokenizer" },
      scientific: { model: "allenai/scibert_scivocab_uncased", note: "May require specific tokenizer" }
    }
  end
end
74
+
75
# Keep the pre-existing extract_entities reachable under a private-looking
# name. This alias must run before the method is redefined just below.
alias_method :_extract_entities, :extract_entities

# Extract entities from text.
#
# Keyword-argument front end: the threshold is handed to the aliased
# implementation as a second positional argument.
#
# @param text [String] The text to analyze
# @param confidence_threshold [Float] Minimum confidence score (default: 0.9)
# @return [Array<Hash>] Array of entity hashes with text, label, start, end, confidence
def extract_entities(text, confidence_threshold: 0.9)
  minimum_confidence = confidence_threshold
  _extract_entities(text, minimum_confidence)
end
87
+
88
# List the entity types known to the model.
#
# Reads the keys of `labels["label2id"]`, drops the "O" label, strips a
# leading "B-"/"I-" prefix, then de-duplicates and sorts. The result is
# memoized in @entity_types.
#
# @return [Array<String>] List of entity types (without B-/I- prefixes)
def entity_types
  @entity_types ||= begin
    label_names = labels["label2id"].keys
    bare_types = label_names.filter_map do |label_name|
      label_name.sub(/^[BI]-/, "") unless label_name == "O"
    end
    bare_types.uniq.sort
  end
end
101
+
102
# Check if model supports a specific entity type.
#
# The comparison ignores case. #entity_types keeps labels exactly as they
# come back from the model (e.g. "Disease"), so upcasing only the query
# would never match mixed- or lower-case labels.
#
# @param entity_type [String, Symbol] Entity type to check (e.g., "GENE", "PER")
# @return [Boolean] Whether the model recognizes this entity type
def supports_entity?(entity_type)
  wanted = entity_type.to_s
  entity_types.any? { |known| known.casecmp?(wanted) }
end
109
+
110
# Extract entities of a specific type.
#
# Labels are compared without regard to case, so a query of "disease"
# matches entities labelled "Disease" or "DISEASE". Every label that the
# previous upcase-only comparison accepted still matches.
#
# @param text [String] The text to analyze
# @param entity_type [String, Symbol] Entity type to extract (e.g., "PER", "ORG")
# @param confidence_threshold [Float] Minimum confidence score
# @return [Array<Hash>] Filtered entities of the specified type
def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
  wanted = entity_type.to_s
  extract_entities(text, confidence_threshold: confidence_threshold)
    .select { |entity| entity[:label].to_s.casecmp?(wanted) }
end
120
+
121
# Run both entity extraction and token-level prediction on one text.
#
# Entities are extracted first (honouring the threshold), then
# `predict_tokens` is called with the text alone.
#
# @param text [String] The text to analyze
# @param confidence_threshold [Float] Minimum confidence for entities
# @return [Hash] Hash with :entities and :tokens keys
def analyze(text, confidence_threshold: 0.9)
  found_entities = extract_entities(text, confidence_threshold: confidence_threshold)
  token_predictions = predict_tokens(text)

  { entities: found_entities, tokens: token_predictions }
end
132
+
133
# Annotate a text with its entities.
#
# After each entity a "[LABEL:confidence]" tag is inserted at the entity's
# :end offset, with confidence rounded to two decimals. Entities are
# processed from the highest :start downwards so that earlier offsets stay
# valid while tags are inserted. When nothing is found the original text
# object is returned unchanged.
#
# @param text [String] The text to analyze
# @param confidence_threshold [Float] Minimum confidence score
# @return [String] Formatted output with entities highlighted
def format_entities(text, confidence_threshold: 0.9)
  found = extract_entities(text, confidence_threshold: confidence_threshold)
  return text if found.empty?

  # In-place sort, latest entity first.
  found.sort_by! { |entity| -entity[:start] }

  found.each_with_object(text.dup) do |entity, annotated|
    tag = "[#{entity[:label]}:#{entity[:confidence].round(2)}]"
    annotated.insert(entity[:end], tag)
  end
end
153
+
154
# Describe the model in one line.
#
# Reads `options` (any StandardError while doing so is swallowed and treated
# as "no options") and reports model and device, plus the label count and
# sorted entity types when those entries are present.
#
# @return [String] Model description
def inspect
  opts =
    begin
      options
    rescue StandardError
      {}
    end

  fields = [
    "model=#{opts["model_id"] || "unknown"}",
    "device=#{opts["device"] || "unknown"}"
  ]
  fields << "labels=#{opts["num_labels"]}" if opts["num_labels"]

  type_names = opts["entity_types"]
  fields << "types=#{type_names.sort.join(",")}" if type_names && !type_names.empty?

  "#<Candle::NER #{fields.join(" ")}>"
end

alias to_s inspect
174
+ end
175
+
176
# Pattern-based entity recognizer for custom entities.
#
# Every match of every configured pattern becomes an entity hash labelled
# with this recognizer's entity type and a confidence of 1.0.
class PatternEntityRecognizer
  # Upper bound on the number of characters scanned per call, to limit the
  # damage a pathological regexp can do on very long input (this matters
  # most on Ruby < 3.2).
  MAX_TEXT_LENGTH = 1_000_000

  attr_reader :patterns, :entity_type

  # @param entity_type [String] Label assigned to every match
  # @param patterns [Array<String, Regexp>, String, Regexp] Initial patterns.
  #   A single pattern is accepted too. The collection is copied, so later
  #   #add_pattern calls never modify the caller's array.
  def initialize(entity_type, patterns = [])
    @entity_type = entity_type
    @patterns = Array(patterns).dup
  end

  # Add a pattern (String or Regexp).
  # Strings are compiled with Regexp.new, i.e. treated as regexp source.
  #
  # @return [self]
  def add_pattern(pattern)
    @patterns << pattern
    self
  end

  # Recognize entities using patterns.
  #
  # Input longer than MAX_TEXT_LENGTH characters is truncated, with a
  # warning sent to Candle.logger. Zero-width matches (for example from
  # /\d*/) are skipped because they would yield entities with empty text.
  #
  # @param text [String] Text to scan
  # @param tokenizer [Object, nil] Accepted for interface parity; not used here
  # @return [Array<Hash>] Entities with :text, :label, :start, :end,
  #   :confidence and :source keys, grouped by pattern in pattern order
  def recognize(text, tokenizer = nil)
    if text.length > MAX_TEXT_LENGTH
      Candle.logger.warn "PatternEntityRecognizer: Text truncated from #{text.length} to #{MAX_TEXT_LENGTH} chars for safety"
      text = text[0...MAX_TEXT_LENGTH]
    end

    @patterns.flat_map do |pattern|
      regex = pattern.is_a?(Regexp) ? pattern : Regexp.new(pattern)
      hits = []

      text.scan(regex) do
        match = Regexp.last_match
        first, last = match.offset(0)
        next if first == last

        hits << {
          text: match[0],
          label: @entity_type,
          start: first,
          end: last,
          confidence: 1.0,
          source: "pattern"
        }
      end

      hits
    end
  end
end
225
+
226
# Gazetteer-based entity recognizer.
#
# Holds a set of terms and reports every whole-word occurrence of a term
# in a text as an entity with confidence 1.0.
class GazetteerEntityRecognizer
  attr_reader :entity_type, :terms, :case_sensitive

  # @param entity_type [String] Label assigned to every match
  # @param terms [Array<String>, String, nil] Initial terms
  # @param case_sensitive [Boolean] When false (default) terms and text are
  #   compared after downcasing
  def initialize(entity_type, terms = [], case_sensitive: false)
    @entity_type = entity_type
    @case_sensitive = case_sensitive
    @terms = Set.new
    add_terms(terms)
  end

  # Add terms to the gazetteer.
  #
  # Accepts a single term or any collection Kernel#Array can flatten
  # (Array, Set, nil). Empty terms are ignored: an empty string "matches"
  # at every position and would only produce empty entities.
  #
  # @param terms [Array<String>, String, nil]
  # @return [self]
  def add_terms(terms)
    Array(terms).each do |term|
      normalized = normalize_term(term)
      @terms.add(normalized) unless normalized.empty?
    end
    self
  end

  # Load terms from file, one per line. Blank lines and lines starting
  # with "#" (after stripping whitespace) are skipped.
  #
  # @param filepath [String]
  # @return [self]
  def load_from_file(filepath)
    File.foreach(filepath) do |line|
      term = line.strip
      add_terms(term) unless term.empty? || term.start_with?("#")
    end
    self
  end

  # Recognize entities using the gazetteer.
  #
  # A hit counts only when the characters on both sides are not word
  # characters (string start/end count as boundaries). Overlapping
  # occurrences of the same term are all examined.
  #
  # NOTE(review): offsets are found in the downcased text and applied to
  # the original. This assumes String#downcase keeps the character count;
  # a few characters (e.g. U+0130) do not, which would shift offsets.
  #
  # @param text [String] Text to scan
  # @param tokenizer [Object, nil] Accepted for interface parity; not used here
  # @return [Array<Hash>] Entities with :text, :label, :start, :end,
  #   :confidence and :source keys
  def recognize(text, tokenizer = nil)
    haystack = @case_sensitive ? text : text.downcase
    entities = []

    @terms.each do |term|
      next if term.empty?

      width = term.length
      from = 0

      while (idx = haystack.index(term, from))
        stop = idx + width
        before = idx.zero? ? " " : haystack[idx - 1]
        after = stop < haystack.length ? haystack[stop] : " "

        if word_boundary?(before) && word_boundary?(after)
          entities << {
            text: text[idx, width],
            label: @entity_type,
            start: idx,
            end: stop,
            confidence: 1.0,
            source: "gazetteer"
          }
        end

        from = idx + 1
      end
    end

    entities
  end

  private

  # Terms are stored downcased unless the recognizer is case sensitive.
  def normalize_term(term)
    value = term.to_s
    @case_sensitive ? value : value.downcase
  end

  # True when +char+ cannot be part of a word. Uses the Unicode-aware
  # [[:word:]] class: Ruby's \W is ASCII-only and would treat letters such
  # as "é" as boundaries, letting "caf" match inside "café".
  def word_boundary?(char)
    !char.match?(/[[:word:]]/)
  end
end
298
+
299
# Hybrid NER that combines ML model with rules.
#
# Entities from an optional model, from pattern recognizers and from
# gazetteer recognizers are pooled and overlapping ones are reduced to a
# single winner.
class HybridNER
  attr_reader :model_ner, :pattern_recognizers, :gazetteer_recognizers

  # @param model_id [String, nil] Model to load via NER.from_pretrained;
  #   nil creates a rules-only recognizer
  # @param device [Device, nil] Passed through to NER.from_pretrained as given
  #   NOTE(review): nil is forwarded explicitly, so NER.from_pretrained's own
  #   default device is not applied - confirm nil is handled downstream.
  # @param tokenizer [String, nil] Tokenizer model ID, passed through to
  #   NER.from_pretrained (needed for models without their own tokenizer)
  def initialize(model_id = nil, device: nil, tokenizer: nil)
    @model_ner = model_id ? NER.from_pretrained(model_id, device: device, tokenizer: tokenizer) : nil
    @pattern_recognizers = []
    @gazetteer_recognizers = []
  end

  # Add a pattern-based recognizer
  #
  # @return [self]
  def add_pattern_recognizer(entity_type, patterns)
    @pattern_recognizers << PatternEntityRecognizer.new(entity_type, patterns)
    self
  end

  # Add a gazetteer-based recognizer. Extra keyword options are forwarded to
  # GazetteerEntityRecognizer.new.
  #
  # @return [self]
  def add_gazetteer_recognizer(entity_type, terms, **options)
    @gazetteer_recognizers << GazetteerEntityRecognizer.new(entity_type, terms, **options)
    self
  end

  # Extract entities using all recognizers.
  #
  # The confidence threshold applies to the model only; the rule-based
  # recognizers are called with the text alone.
  #
  # @param text [String] The text to analyze
  # @param confidence_threshold [Float] Minimum confidence for model entities
  # @return [Array<Hash>] Non-overlapping entities ordered by :start
  def extract_entities(text, confidence_threshold: 0.9)
    candidates = []
    candidates.concat(@model_ner.extract_entities(text, confidence_threshold: confidence_threshold)) if @model_ner

    (@pattern_recognizers + @gazetteer_recognizers).each do |recognizer|
      candidates.concat(recognizer.recognize(text))
    end

    # Merge overlapping entities (prefer highest confidence)
    merge_entities(candidates)
  end

  private

  # Resolve overlaps, preferring the highest confidence.
  #
  # Candidates are considered from most to least confident; a candidate is
  # kept only if it does not overlap one already kept. Ties on confidence go
  # to the earlier start, then to the longer span. (Ordering by start first
  # would let a low-confidence entity that merely begins earlier displace a
  # higher-confidence one.)
  #
  # @param entities [Array<Hash>] Hashes with :start, :end and :confidence
  # @return [Array<Hash>] Surviving entities ordered by :start
  def merge_entities(entities)
    ranked = entities.sort_by do |entity|
      [-entity[:confidence], entity[:start], entity[:start] - entity[:end]]
    end

    kept = ranked.each_with_object([]) do |candidate, winners|
      clash = winners.any? do |winner|
        candidate[:start] < winner[:end] && candidate[:end] > winner[:start]
      end
      winners << candidate unless clash
    end

    kept.sort_by { |entity| entity[:start] }
  end
end
368
+ end
@@ -0,0 +1,45 @@
1
+ module Candle
2
class Reranker
  # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
  DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"

  class << self
    # Load a pre-trained reranker model from HuggingFace.
    # Delegates to `_create`, which is defined outside this file.
    #
    # @param model_id [String] HuggingFace model ID (defaults to DEFAULT_MODEL_PATH)
    # @param device [Candle::Device] The device to use for computation (defaults to Candle::Device.best)
    # @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
    # @return [Reranker] A new Reranker instance
    def from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
      _create(model_id, device, max_length)
    end

    # Keyword-only constructor kept for older callers. Prints a deprecation
    # notice to $stderr on every call, then delegates to `_create`.
    #
    # @deprecated Use {.from_pretrained} instead
    # @param model_path [String] The path to the model on Hugging Face
    # @param device [Candle::Device] The device to use for computation (defaults to Candle::Device.best)
    # @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
    def new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
      $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
      _create(model_path, device, max_length)
    end
  end

  # Returns documents ranked by relevance using the specified pooling method.
  #
  # Each [document, score, doc_id] triple produced by `rerank_with_options`
  # is converted into a Hash, keeping the order in which they were returned.
  #
  # @param query [String] The input text
  # @param documents [Array<String>] The list of documents to compare against
  # @param pooling_method [String] Pooling method: "pooler", "cls", or "mean". Default: "pooler"
  # @param apply_sigmoid [Boolean] Whether to apply sigmoid to the scores. Default: true
  # @return [Array<Hash>] Hashes with :doc_id, :score and :text
  def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
    scored = rerank_with_options(query, documents, pooling_method, apply_sigmoid)

    scored.map do |text, score, doc_id|
      { doc_id: doc_id, score: score, text: text }
    end
  end

  # One-line description with model and device. Any StandardError raised
  # while reading `options` is swallowed and both fields show "unknown".
  def inspect
    opts =
      begin
        options
      rescue StandardError
        {}
      end

    model = opts["model_id"] || "unknown"
    device = opts["device"] || "unknown"
    "#<Candle::Reranker model=#{model} device=#{device}>"
  end
end
45
+ end
@@ -0,0 +1,99 @@
1
+ module Candle
2
+ class Tensor
3
+ include Enumerable
4
+
5
# Iterate over the tensor (this is what powers Enumerable).
#
# * rank 0: yields the single value from `item`
# * rank 1: yields each element
# * higher ranks: yields `self[i]` for each index along the first dimension
#
# Without a block an Enumerator is returned, as is customary for #each.
#
# For 1D tensors whose dtype prints as "f32", `values_f32` is preferred when
# the object provides it. Availability is checked with respond_to? rather
# than by rescuing NoMethodError: a rescue around the yield would also
# catch a NoMethodError raised by the caller's block and then run the
# fallback, yielding every element a second time.
#
# @yield [element] each value (rank 0/1) or sub-tensor (rank >= 2)
# @return [Enumerator] when no block is given
def each
  return enum_for(:each) unless block_given?

  case rank
  when 0
    # Scalar tensor - yield the single value
    yield item
  when 1
    elements =
      if dtype.to_s.downcase != "f32"
        # For non-F32 dtypes, use regular values
        values
      elsif respond_to?(:values_f32)
        values_f32
      # :nocov:
      elsif device.to_s != "cpu"
        # values_f32 isn't available (extension not recompiled): move to CPU
        # first to avoid the Metal F32->F64 conversion issue
        to_device(Candle::Device.cpu).values
      else
        values
      # :nocov:
      end

    elements.each { |value| yield value }
  else
    # Multi-dimensional tensor - yield each sub-tensor
    shape.first.times { |index| yield self[index] }
  end
end
38
+
39
# Return the value of a scalar tensor.
#
# Only rank-0 tensors are accepted; the value itself comes from `item`.
#
# @return [Object] whatever `item` returns for this tensor
# @raise [ArgumentError] if the tensor is not rank 0
def to_f
  unless rank == 0
    raise ArgumentError, "to_f can only be called on scalar tensors (rank 0), but this tensor has rank #{rank}"
  end

  # item is expected to handle the dtype conversion
  item
end
48
+
49
# Convert scalar tensor to integer.
#
# Goes through #to_f first, so the same rank-0 restriction applies, and
# then calls #to_i on that value.
def to_i
  scalar_value = to_f
  scalar_value.to_i
end
53
+
54
# Compact description of the tensor: shape (dimensions joined with "x"),
# dtype, device and total element count.
#
# @return [String]
def inspect
  dimensions = shape.join("x")

  "#<Candle::Tensor shape=#{dimensions} dtype=#{dtype} device=#{device} elements=#{elem_count}>"
end
68
+
69
+
70
# Override class methods to support keyword arguments for device.
#
# Each existing constructor is first preserved under an `_original_` name
# and then replaced by a wrapper that accepts `device:` as a keyword and
# forwards it as the last positional argument.
class << self
  %i[new ones zeros rand randn].each do |constructor|
    alias_method :"_original_#{constructor}", constructor
  end

  # @param data [Object] forwarded unchanged
  # @param dtype [Object, nil] forwarded unchanged (nil by default)
  # @param device [Object, nil] forwarded as third positional argument
  def new(data, dtype = nil, device: nil)
    _original_new(data, dtype, device)
  end

  # @param shape [Object] forwarded unchanged
  # @param device [Object, nil] forwarded as second positional argument
  def ones(shape, device: nil)
    _original_ones(shape, device)
  end

  # @param shape [Object] forwarded unchanged
  # @param device [Object, nil] forwarded as second positional argument
  def zeros(shape, device: nil)
    _original_zeros(shape, device)
  end

  # @param shape [Object] forwarded unchanged
  # @param device [Object, nil] forwarded as second positional argument
  def rand(shape, device: nil)
    _original_rand(shape, device)
  end

  # @param shape [Object] forwarded unchanged
  # @param device [Object, nil] forwarded as second positional argument
  def randn(shape, device: nil)
    _original_randn(shape, device)
  end
end
98
+ end
99
+ end