red-candle 1.8.0.pre2-x86_64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5193 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +33 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
require 'logger'
|
|
2
|
+
|
|
3
|
+
module Candle
|
|
4
|
+
# Logging functionality for the Red Candle gem
|
|
5
|
+
class << self
  # Returns the gem-wide logger, building the default one lazily on first use.
  #
  # @return [Logger] the memoized logger instance
  def logger
    @logger ||= create_default_logger
  end

  # Replaces the gem-wide logger with a caller-supplied instance.
  #
  # @param custom_logger [Logger] logger to use from now on
  def logger=(custom_logger)
    @logger = custom_logger
  end

  # Builds a fresh logger from a LoggerConfig and installs it as the gem-wide
  # logger. When a block is given it receives the config before the logger is built.
  #
  # @yield [config] the LoggerConfig to customise
  def configure_logging
    config = LoggerConfig.new
    yield(config) if block_given?
    @logger = config.build_logger
  end

  private

  # Default logger: writes to $stderr, message-only output.
  #
  # @return [Logger] configured logger instance
  def create_default_logger
    Logger.new($stderr).tap do |default|
      default.level = default_log_level
      default.formatter = cli_friendly_formatter
    end
  end

  # Any value in CANDLE_VERBOSE (legacy switch) turns on DEBUG; otherwise only
  # warnings and errors are shown, which keeps CLI output quiet.
  #
  # @return [Integer] Logger severity constant
  def default_log_level
    ENV['CANDLE_VERBOSE'] ? Logger::DEBUG : Logger::WARN
  end

  # Formatter that emits just the message followed by a newline.
  #
  # @return [Proc] formatter proc
  def cli_friendly_formatter
    proc { |_severity, _time, _progname, message| "#{message}\n" }
  end
end
|
|
51
|
+
|
|
52
|
+
# Configuration helper for logger setup
|
|
53
|
+
class LoggerConfig
  attr_accessor :level, :output, :formatter

  # Defaults: WARN level, output to $stderr, message-only formatter.
  def initialize
    @level = :warn
    @output = $stderr
    @formatter = :simple
  end

  # Build a logger from the configuration
  # @return [Logger] Configured logger
  def build_logger
    logger = Logger.new(@output)
    logger.level = normalize_level(@level)
    logger.formatter = build_formatter(@formatter)
    logger
  end

  # Set log level to debug (verbose output)
  def verbose!
    @level = :debug
  end

  # Set log level to info
  def info!
    @level = :info
  end

  # Set log level to warn (default)
  def quiet!
    @level = :warn
  end

  # Set log level to error (minimal output)
  def silent!
    @level = :error
  end

  # Log to stdout instead of stderr
  def log_to_stdout!
    @output = $stdout
  end

  # Log to a file
  # @param file_path [String] Path to log file (Logger opens it itself)
  def log_to_file!(file_path)
    @output = file_path
  end

  # Disable logging completely by writing to the null device
  def disable!
    @output = File::NULL
  end

  private

  # Convert a symbol/string/integer level to a Logger severity constant.
  # Unrecognised names fall back to WARN.
  #
  # @param level [Symbol, String, Integer] Log level
  # @return [Integer] Logger level constant
  def normalize_level(level)
    # Logger::DEBUG..Logger::UNKNOWN are Integers; pass them through unchanged
    # instead of stringifying them (which previously turned every Integer into WARN).
    return level if level.is_a?(Integer)

    case level.to_s.downcase
    when 'debug' then Logger::DEBUG
    when 'info' then Logger::INFO
    when 'warn', 'warning' then Logger::WARN
    when 'error' then Logger::ERROR
    when 'fatal' then Logger::FATAL
    else Logger::WARN
    end
  end

  # Build formatter based on type.
  #
  # @param formatter_type [Symbol] :detailed, :json, or anything else
  #   (:simple, :cli, unknown) for message-only output
  # @return [Proc] Formatter proc
  def build_formatter(formatter_type)
    case formatter_type
    when :detailed
      proc do |severity, datetime, _progname, msg|
        "[#{datetime.strftime('%Y-%m-%d %H:%M:%S')}] #{severity}: #{msg}\n"
      end
    when :json
      # Lazily loaded: only needed for JSON output.
      require 'json'
      # Time#iso8601 lives in the 'time' stdlib (only built into core since Ruby 3.4).
      require 'time'
      proc do |severity, datetime, progname, msg|
        JSON.generate({
          timestamp: datetime.iso8601,
          level: severity,
          message: msg,
          program: progname
        }) + "\n"
      end
    else
      # :simple, :cli and unknown values all print just the message.
      proc { |_severity, _datetime, _progname, msg| "#{msg}\n" }
    end
  end
end
|
|
149
|
+
end
|
data/lib/candle/ner.rb
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# Pattern validation available but not forced
|
|
4
|
+
# require_relative 'pattern_validator' # Uncomment if needed
require 'set'
|
|
5
|
+
|
|
6
|
+
module Candle
|
|
7
|
+
# Named Entity Recognition (NER) for token classification
|
|
8
|
+
#
|
|
9
|
+
# This class provides methods to extract named entities from text using
|
|
10
|
+
# pre-trained BERT-based models. It supports standard NER labels like
|
|
11
|
+
# PER (person), ORG (organization), LOC (location), and can be extended
|
|
12
|
+
# with custom entity types.
|
|
13
|
+
#
|
|
14
|
+
# @example Load a pre-trained NER model
|
|
15
|
+
# ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
|
|
16
|
+
#
|
|
17
|
+
# @example Load a model with a specific tokenizer
|
|
18
|
+
# ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
|
|
19
|
+
#
|
|
20
|
+
# @example Extract entities from text
|
|
21
|
+
# entities = ner.extract_entities("Apple Inc. was founded by Steve Jobs in Cupertino.")
|
|
22
|
+
# # => [
|
|
23
|
+
# # { text: "Apple Inc.", label: "ORG", start: 0, end: 10, confidence: 0.99 },
|
|
24
|
+
# # { text: "Steve Jobs", label: "PER", start: 26, end: 36, confidence: 0.98 },
|
|
25
|
+
# # { text: "Cupertino", label: "LOC", start: 40, end: 49, confidence: 0.97 }
|
|
26
|
+
# # ]
|
|
27
|
+
#
|
|
28
|
+
# @example Get token-level predictions
|
|
29
|
+
# tokens = ner.predict_tokens("John works at Google")
|
|
30
|
+
# # Returns detailed token-by-token predictions with confidence scores
|
|
31
|
+
class NER
  class << self
    # Load a pre-trained NER model from HuggingFace
    #
    # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
    # @param device [Device] Device to run on (defaults to best available)
    # @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
    # @return [NER] NER instance
    def from_pretrained(model_id, device: Candle::Device.best, tokenizer: nil)
      new(model_id, device, tokenizer)
    end

    # Popular pre-trained models for different domains
    #
    # @return [Hash{Symbol => Hash}] domain => { model:, note: }, plus :tokenizer
    #   for models whose repository needs a separate tokenizer
    def suggested_models
      {
        general: {
          model: "Babelscape/wikineural-multilingual-ner",
          note: "Has tokenizer.json"
        },
        general_alt: {
          model: "dslim/bert-base-NER",
          tokenizer: "bert-base-cased",
          note: "Requires separate tokenizer"
        },
        multilingual: {
          model: "Davlan/bert-base-multilingual-cased-ner-hrl",
          note: "Check tokenizer availability"
        },
        biomedical: {
          model: "dmis-lab/biobert-base-cased-v1.2",
          note: "May require specific tokenizer"
        },
        clinical: {
          model: "emilyalsentzer/Bio_ClinicalBERT",
          note: "May require specific tokenizer"
        },
        scientific: {
          model: "allenai/scibert_scivocab_uncased",
          note: "May require specific tokenizer"
        }
      }
    end
  end

  # Keep the native implementation (positional arguments) reachable so the
  # keyword-argument wrapper below can delegate to it.
  alias_method :_extract_entities, :extract_entities

  # Extract entities from text
  #
  # @param text [String] The text to analyze
  # @param confidence_threshold [Float] Minimum confidence score (default: 0.9)
  # @return [Array<Hash>] Array of entity hashes with text, label, start, end, confidence
  def extract_entities(text, confidence_threshold: 0.9)
    _extract_entities(text, confidence_threshold)
  end

  # Get available entity types
  #
  # Read from the model's label2id map. "O" is dropped, and tagging-scheme
  # prefixes are stripped: B-/I- (BIO), E-/S- (BIOES) and L-/U- (BILOU).
  # The result is memoized.
  #
  # @return [Array<String>] Sorted, de-duplicated entity types
  def entity_types
    @entity_types ||= labels["label2id"].keys
      .reject { |label| label == "O" }
      .map { |label| label.sub(/\A[BIESLU]-/, "") }
      .uniq
      .sort
  end

  # Check if model supports a specific entity type (case-insensitive)
  #
  # @param entity_type [String, Symbol] Entity type to check (e.g., "GENE", "PER", :per)
  # @return [Boolean] Whether the model recognizes this entity type
  def supports_entity?(entity_type)
    wanted = entity_type.to_s
    entity_types.any? { |type| type.casecmp?(wanted) }
  end

  # Extract entities of a specific type (label comparison is case-insensitive)
  #
  # @param text [String] The text to analyze
  # @param entity_type [String, Symbol] Entity type to extract (e.g., "PER", "ORG")
  # @param confidence_threshold [Float] Minimum confidence score
  # @return [Array<Hash>] Filtered entities of the specified type
  def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
    wanted = entity_type.to_s
    extract_entities(text, confidence_threshold: confidence_threshold)
      .select { |entity| entity[:label].to_s.casecmp?(wanted) }
  end

  # Analyze text and return both entities and token predictions
  #
  # @param text [String] The text to analyze
  # @param confidence_threshold [Float] Minimum confidence for entities
  # @return [Hash] Hash with :entities and :tokens keys
  def analyze(text, confidence_threshold: 0.9)
    {
      entities: extract_entities(text, confidence_threshold: confidence_threshold),
      tokens: predict_tokens(text)
    }
  end

  # Get a formatted string representation of entities
  #
  # Each entity is followed by "[LABEL:confidence]" inserted at its end offset.
  #
  # @param text [String] The text to analyze
  # @param confidence_threshold [Float] Minimum confidence score
  # @return [String] Formatted output with entities highlighted
  def format_entities(text, confidence_threshold: 0.9)
    entities = extract_entities(text, confidence_threshold: confidence_threshold)
    return text if entities.empty?

    # Insert at the right-most end offset first so that offsets still to be
    # processed are unaffected by earlier insertions. This ordering also holds
    # for nested spans, whereas ordering by start does not.
    entities.sort_by { |entity| -entity[:end] }.each_with_object(text.dup) do |entity, result|
      result.insert(entity[:end], "[#{entity[:label]}:#{entity[:confidence].round(2)}]")
    end
  end

  # Get model information
  #
  # @return [String] Model description
  def inspect
    # Best effort: inspect must never raise, even if the native options call fails.
    opts =
      begin
        options
      rescue StandardError
        {}
      end

    parts = ["#<Candle::NER"]
    parts << "model=#{opts["model_id"] || "unknown"}"
    parts << "device=#{opts["device"] || "unknown"}"
    parts << "labels=#{opts["num_labels"]}" if opts["num_labels"]

    if opts["entity_types"] && !opts["entity_types"].empty?
      types = opts["entity_types"].sort.join(",")
      parts << "types=#{types}"
    end

    parts.join(" ") + ">"
  end

  alias to_s inspect
end
|
|
175
|
+
|
|
176
|
+
# Pattern-based entity recognizer for custom entities
|
|
177
|
+
class PatternEntityRecognizer
  attr_reader :patterns, :entity_type

  # @param entity_type [String] label attached to every match
  # @param patterns [Array<String, Regexp>, String, Regexp] initial patterns. The
  #   collection is copied, so #add_pattern never mutates the caller's array.
  def initialize(entity_type, patterns = [])
    @entity_type = entity_type
    @patterns = Array(patterns).dup
  end

  # Add a pattern (String or Regexp). Strings are compiled with Regexp.new, so
  # they are treated as regular-expression source, not literal text.
  #
  # @return [self] for chaining
  def add_pattern(pattern)
    @patterns << pattern
    self
  end

  # Recognize entities using patterns
  #
  # @param text [String] text to scan
  # @param tokenizer [Object, nil] accepted for interface parity; not used here
  # @return [Array<Hash>] one hash per non-empty match with text, label, start,
  #   end (character offsets into text), confidence 1.0 and source "pattern"
  def recognize(text, tokenizer = nil)
    # Limit text length to prevent ReDoS on very long strings
    # This is especially important for Ruby < 3.2
    max_length = 1_000_000 # characters, not bytes
    if text.length > max_length
      Candle.logger.warn "PatternEntityRecognizer: Text truncated from #{text.length} to #{max_length} chars for safety"
      text = text[0...max_length]
    end

    @patterns.each_with_object([]) do |pattern, entities|
      regex = pattern.is_a?(Regexp) ? pattern : Regexp.new(pattern)

      text.scan(regex) do
        match = $~
        # Zero-width matches (e.g. /\d*/ between non-digits) have no entity text.
        next if match[0].empty?

        entities << {
          text: match[0],
          label: @entity_type,
          start: match.begin(0),
          end: match.end(0),
          confidence: 1.0,
          source: "pattern"
        }
      end
    end
  end
end
|
|
225
|
+
|
|
226
|
+
# Gazetteer-based entity recognizer
|
|
227
|
+
class GazetteerEntityRecognizer
  attr_reader :entity_type, :terms, :case_sensitive

  # @param entity_type [String] label attached to every match
  # @param terms [Enumerable<String>] initial dictionary terms
  # @param case_sensitive [Boolean] when false, terms are stored downcased and
  #   matched case-insensitively
  def initialize(entity_type, terms = [], case_sensitive: false)
    @entity_type = entity_type
    @case_sensitive = case_sensitive
    @terms = build_term_set(terms)
  end

  # Add terms to the gazetteer
  #
  # @param terms [String, Array<String>, Set<String>, nil] one term or a collection
  # @return [self] for chaining
  def add_terms(terms)
    Array(terms).each { |term| @terms.add(normalize_term(term)) }
    self
  end

  # Load terms from file, one per line; blank lines and lines starting with "#" are skipped.
  #
  # @param filepath [String] path to the term list
  # @return [self] for chaining
  def load_from_file(filepath)
    File.foreach(filepath) do |line|
      term = line.strip
      next if term.empty? || term.start_with?("#")

      add_terms(term)
    end
    self
  end

  # Recognize entities using the gazetteer
  #
  # Every (possibly overlapping) whole-word occurrence of every term is
  # reported. Offsets are character indices into the ORIGINAL text. Matching
  # runs on the original text rather than on a downcased copy, so the indices
  # stay valid even where downcasing would change the string's length.
  #
  # @param text [String] text to scan
  # @param tokenizer [Object, nil] accepted for interface parity; not used here
  # @return [Array<Hash>] entity hashes with source "gazetteer" and confidence 1.0
  def recognize(text, tokenizer = nil)
    entities = []

    @terms.each do |term|
      # An empty term would "match" between every pair of characters.
      next if term.to_s.empty?

      matcher = term_matcher(term)
      pos = 0

      while (match = matcher.match(text, pos))
        first = match.begin(0)
        last = match.end(0)

        if boundary_at?(text, first - 1) && boundary_at?(text, last)
          entities << {
            text: match[0],
            label: @entity_type,
            start: first,
            end: last,
            confidence: 1.0,
            source: "gazetteer"
          }
        end

        # Advance by one character so overlapping occurrences are also found.
        pos = first + 1
      end
    end

    entities
  end

  private

  def build_term_set(terms)
    Set.new(terms.map { |term| normalize_term(term) })
  end

  def normalize_term(term)
    @case_sensitive ? term : term.downcase
  end

  # Literal (escaped) regex for a term; case-insensitive unless configured otherwise.
  def term_matcher(term)
    options = @case_sensitive ? 0 : Regexp::IGNORECASE
    Regexp.new(Regexp.escape(term.to_s), options)
  end

  # True when the character at +index+ does not continue a word. Positions
  # before the start or past the end of the text count as boundaries.
  def boundary_at?(text, index)
    return true if index.negative?

    char = text[index]
    char.nil? || char.empty? || word_boundary?(char)
  end

  # Unicode-aware: [[:word:]] covers letters, marks, numbers and connector
  # punctuation. Ruby's \w is ASCII-only, so accented letters would otherwise
  # be treated as boundaries.
  def word_boundary?(char)
    !char.match?(/[[:word:]]/)
  end
end
|
|
298
|
+
|
|
299
|
+
# Hybrid NER that combines ML model with rules
|
|
300
|
+
class HybridNER
  attr_reader :model_ner, :pattern_recognizers, :gazetteer_recognizers

  # @param model_id [String, nil] HuggingFace NER model to load; nil builds a
  #   rules-only recognizer
  # @param device [Candle::Device, nil] device for the model; nil leaves the
  #   choice to NER.from_pretrained's own default
  def initialize(model_id = nil, device: nil)
    @model_ner = model_id ? load_model_ner(model_id, device) : nil
    @pattern_recognizers = []
    @gazetteer_recognizers = []
  end

  # Add a pattern-based recognizer
  # @return [self] for chaining
  def add_pattern_recognizer(entity_type, patterns)
    @pattern_recognizers << PatternEntityRecognizer.new(entity_type, patterns)
    self
  end

  # Add a gazetteer-based recognizer (options are forwarded, e.g. case_sensitive:)
  # @return [self] for chaining
  def add_gazetteer_recognizer(entity_type, terms, **options)
    @gazetteer_recognizers << GazetteerEntityRecognizer.new(entity_type, terms, **options)
    self
  end

  # Extract entities using all recognizers
  #
  # @param text [String] text to analyze
  # @param confidence_threshold [Float] threshold applied to the ML model only
  #   (rule-based matches always have confidence 1.0)
  # @return [Array<Hash>] non-overlapping entities ordered by start offset
  def extract_entities(text, confidence_threshold: 0.9)
    all_entities = []

    if @model_ner
      all_entities.concat(@model_ner.extract_entities(text, confidence_threshold: confidence_threshold))
    end

    (@pattern_recognizers + @gazetteer_recognizers).each do |recognizer|
      all_entities.concat(recognizer.recognize(text))
    end

    merge_entities(all_entities)
  end

  private

  # Only forward an explicit device; passing device: nil would override
  # NER.from_pretrained's default of the best available device.
  def load_model_ner(model_id, device)
    return NER.from_pretrained(model_id) if device.nil?

    NER.from_pretrained(model_id, device: device)
  end

  # Resolve overlaps greedily, highest confidence first. Ties go to the
  # earlier start, then the longer span. Survivors are returned sorted by start.
  def merge_entities(entities)
    ranked = entities.sort_by { |e| [-e[:confidence], e[:start], e[:start] - e[:end]] }

    kept = ranked.each_with_object([]) do |entity, accepted|
      overlaps = accepted.any? do |existing|
        entity[:start] < existing[:end] && entity[:end] > existing[:start]
      end
      accepted << entity unless overlaps
    end

    kept.sort_by { |e| e[:start] }
  end
end
|
|
368
|
+
end
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
module Candle
|
|
2
|
+
class Reranker
  # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
  DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2".freeze

  # Load a pre-trained reranker model from HuggingFace
  # @param model_id [String] HuggingFace model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)
  # @param device [Candle::Device] The device to use for computation (defaults to best available)
  # @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
  # @return [Reranker] A new Reranker instance
  def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
    _create(model_id, device, max_length)
  end

  # Constructor for creating a new Reranker with optional parameters
  # @deprecated Use {.from_pretrained} instead
  # @param model_path [String, nil] The path to the model on Hugging Face
  # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
  # @param max_length [Integer] Maximum sequence length for truncation (defaults to 512)
  def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
    # Routed through the gem logger (WARN, $stderr by default) so the notice
    # honours Candle.configure_logging like the gem's other warnings.
    Candle.logger.warn "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
    _create(model_path, device, max_length)
  end

  # Returns documents ranked by relevance using the specified pooling method.
  # @param query [String] The input text
  # @param documents [Array<String>] The list of documents to compare against
  # @param pooling_method [String] Pooling method: "pooler", "cls", or "mean". Default: "pooler"
  # @param apply_sigmoid [Boolean] Whether to apply sigmoid to the scores. Default: true
  # @return [Array<Hash>] { doc_id:, score:, text: } in the order produced by the native ranker
  def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
    rerank_with_options(query, documents, pooling_method, apply_sigmoid).map do |doc, score, doc_id|
      { doc_id: doc_id, score: score, text: doc }
    end
  end

  # Improved inspect method showing model id and device
  def inspect
    # Best effort: inspect must never raise, even if the native options call fails.
    opts =
      begin
        options
      rescue StandardError
        {}
      end

    parts = ["#<Candle::Reranker"]
    parts << "model=#{opts["model_id"] || "unknown"}"
    parts << "device=#{opts["device"] || "unknown"}"
    parts.join(" ") + ">"
  end
end
|
|
45
|
+
end
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
module Candle
|
|
2
|
+
class Tensor
  include Enumerable

  # Iterate over the tensor's outermost dimension.
  #
  # * rank 0  - yields the single scalar (via #item)
  # * rank 1  - yields each element as a Ruby number
  # * rank 2+ - yields each sub-tensor self[i] for i in 0...shape.first
  #
  # Without a block an Enumerator is returned, per the Enumerable convention.
  def each
    return enum_for(:each) unless block_given?

    case rank
    when 0
      yield item
    when 1
      elements_for_each.each { |value| yield value }
    else
      shape.first.times { |i| yield self[i] }
    end
  end

  # Ruby values of a rank-1 tensor, preferring the native F32 accessor when
  # the extension provides it.
  #
  # The method is probed with respond_to? rather than by rescuing NoMethodError.
  # A rescue wrapped around the iteration also caught NoMethodErrors raised
  # inside the caller's block, then restarted the loop and yielded elements twice.
  private def elements_for_each
    return values unless dtype.to_s.downcase == "f32"
    return values_f32 if respond_to?(:values_f32)

    # :nocov:
    # Older builds without values_f32: read from the CPU to avoid the Metal
    # F32->F64 conversion issue.
    device.to_s == "cpu" ? values : to_device(Candle::Device.cpu).values
    # :nocov:
  end

  # Convert scalar tensor to float
  # @raise [ArgumentError] unless the tensor is rank 0
  def to_f
    if rank == 0
      # Use item method which handles dtype conversion properly
      item
    else
      raise ArgumentError, "to_f can only be called on scalar tensors (rank 0), but this tensor has rank #{rank}"
    end
  end

  # Convert scalar tensor to integer (truncates via to_f)
  def to_i
    to_f.to_i
  end

  # Improved inspect method showing shape, dtype, device and element count
  def inspect
    shape_str = shape.join("x")

    parts = ["#<Candle::Tensor"]
    parts << "shape=#{shape_str}"
    parts << "dtype=#{dtype}"
    parts << "device=#{device}"

    # Add element count for clarity
    parts << "elements=#{elem_count}"

    parts.join(" ") + ">"
  end

  # Override class methods to support keyword arguments for device
  class << self
    alias_method :_original_new, :new
    alias_method :_original_ones, :ones
    alias_method :_original_zeros, :zeros
    alias_method :_original_rand, :rand
    alias_method :_original_randn, :randn

    def new(data, dtype = nil, device: nil)
      _original_new(data, dtype, device)
    end

    def ones(shape, device: nil)
      _original_ones(shape, device)
    end

    def zeros(shape, device: nil)
      _original_zeros(shape, device)
    end

    def rand(shape, device: nil)
      _original_rand(shape, device)
    end

    def randn(shape, device: nil)
      _original_randn(shape, device)
    end
  end
end
|
|
99
|
+
end
|