red-candle 1.0.0.pre.7 → 1.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +399 -18
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +5 -0
- data/ext/candle/src/llm/llama.rs +5 -0
- data/ext/candle/src/llm/mistral.rs +5 -0
- data/ext/candle/src/llm/mod.rs +1 -89
- data/ext/candle/src/llm/quantized_gguf.rs +5 -0
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -33
- data/ext/candle/src/ruby/llm.rs +31 -13
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +128 -5
- data/ext/candle/src/ruby/qtensor.rs +0 -69
data/lib/candle/build_info.rb
CHANGED
@@ -7,6 +7,7 @@ module Candle
       return unless ENV['CANDLE_VERBOSE'] || ENV['CANDLE_DEBUG'] || $DEBUG

       if info["cuda_available"] == false
+        # :nocov:
         # Check if CUDA could be available on the system
         cuda_potentially_available = ENV['CUDA_ROOT'] || ENV['CUDA_PATH'] ||
                                      File.exist?('/usr/local/cuda') || File.exist?('/opt/cuda')
@@ -18,6 +19,7 @@ module Candle
           warn "  CANDLE_ENABLE_CUDA=1 gem install red-candle"
           warn "=" * 80
         end
+        # :nocov:
       end
     end

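The `# :nocov:` pairs added here are SimpleCov's default exclusion markers: everything between a pair is left out of coverage reporting, which keeps CUDA-only warning paths from dragging down coverage on machines without a GPU. A minimal illustration of the marker semantics (hypothetical method, assuming SimpleCov's default nocov token):

    def accelerator_hint
      # :nocov:
      return :cuda if ENV['CUDA_ROOT']   # only reachable on CUDA hosts
      # :nocov:
      :cpu
    end
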
data/lib/candle/device_utils.rb
CHANGED
@@ -7,6 +7,7 @@ module Candle
       # Try Metal first (for Mac users)
       Device.metal
     rescue
+      # :nocov:
       begin
         # Try CUDA next (for NVIDIA GPU users)
         Device.cuda
@@ -14,6 +15,7 @@ module Candle
         # Fall back to CPU
         Device.cpu
       end
+      # :nocov:
     end
   end
 end
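The same markers wrap the CUDA branch of the device helper, whose fallback chain is Metal, then CUDA, then CPU. A short usage sketch; the method name `best_device` is an assumption, since the hunk does not show the enclosing method signature:

    require 'candle'

    # Resolves to Metal on Apple hardware, CUDA on NVIDIA hosts, else CPU.
    device = Candle::DeviceUtils.best_device   # assumed entry point
    puts "Using device: #{device}"
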
data/lib/candle/ner.rb
ADDED
@@ -0,0 +1,345 @@
+# frozen_string_literal: true
+
+module Candle
+  # Named Entity Recognition (NER) for token classification
+  #
+  # This class provides methods to extract named entities from text using
+  # pre-trained BERT-based models. It supports standard NER labels like
+  # PER (person), ORG (organization), LOC (location), and can be extended
+  # with custom entity types.
+  #
+  # @example Load a pre-trained NER model
+  #   ner = Candle::NER.from_pretrained("Babelscape/wikineural-multilingual-ner")
+  #
+  # @example Load a model with a specific tokenizer
+  #   ner = Candle::NER.from_pretrained("dslim/bert-base-NER", tokenizer: "bert-base-cased")
+  #
+  # @example Extract entities from text
+  #   entities = ner.extract_entities("Apple Inc. was founded by Steve Jobs in Cupertino.")
+  #   # => [
+  #   #   { text: "Apple Inc.", label: "ORG", start: 0, end: 10, confidence: 0.99 },
+  #   #   { text: "Steve Jobs", label: "PER", start: 26, end: 36, confidence: 0.98 },
+  #   #   { text: "Cupertino", label: "LOC", start: 40, end: 49, confidence: 0.97 }
+  #   # ]
+  #
+  # @example Get token-level predictions
+  #   tokens = ner.predict_tokens("John works at Google")
+  #   # Returns detailed token-by-token predictions with confidence scores
+  class NER
+    class << self
+      # Load a pre-trained NER model from HuggingFace
+      #
+      # @param model_id [String] HuggingFace model ID (e.g., "dslim/bert-base-NER")
+      # @param device [Device, nil] Device to run on (defaults to best available)
+      # @param tokenizer [String, nil] Tokenizer model ID to use (defaults to same as model_id)
+      # @return [NER] NER instance
+      def from_pretrained(model_id, device: nil, tokenizer: nil)
+        new(model_id, device, tokenizer)
+      end
+
+      # Popular pre-trained models for different domains
+      def suggested_models
+        {
+          general: {
+            model: "Babelscape/wikineural-multilingual-ner",
+            note: "Has tokenizer.json"
+          },
+          general_alt: {
+            model: "dslim/bert-base-NER",
+            tokenizer: "bert-base-cased",
+            note: "Requires separate tokenizer"
+          },
+          multilingual: {
+            model: "Davlan/bert-base-multilingual-cased-ner-hrl",
+            note: "Check tokenizer availability"
+          },
+          biomedical: {
+            model: "dmis-lab/biobert-base-cased-v1.2",
+            note: "May require specific tokenizer"
+          },
+          clinical: {
+            model: "emilyalsentzer/Bio_ClinicalBERT",
+            note: "May require specific tokenizer"
+          },
+          scientific: {
+            model: "allenai/scibert_scivocab_uncased",
+            note: "May require specific tokenizer"
+          }
+        }
+      end
+    end
+
+    # Create an alias for the native method
+    alias_method :_extract_entities, :extract_entities
+
+    # Extract entities from text
+    #
+    # @param text [String] The text to analyze
+    # @param confidence_threshold [Float] Minimum confidence score (default: 0.9)
+    # @return [Array<Hash>] Array of entity hashes with text, label, start, end, confidence
+    def extract_entities(text, confidence_threshold: 0.9)
+      # Call the native method with positional arguments
+      _extract_entities(text, confidence_threshold)
+    end
+
+    # Get available entity types
+    #
+    # @return [Array<String>] List of entity types (without B-/I- prefixes)
+    def entity_types
+      return @entity_types if @entity_types
+
+      label_config = labels
+      @entity_types = label_config["label2id"].keys
+                                              .reject { |l| l == "O" }
+                                              .map { |l| l.sub(/^[BI]-/, "") }
+                                              .uniq
+                                              .sort
+    end
+
+    # Check if model supports a specific entity type
+    #
+    # @param entity_type [String] Entity type to check (e.g., "GENE", "PER")
+    # @return [Boolean] Whether the model recognizes this entity type
+    def supports_entity?(entity_type)
+      entity_types.include?(entity_type.upcase)
+    end
+
+    # Extract entities of a specific type
+    #
+    # @param text [String] The text to analyze
+    # @param entity_type [String] Entity type to extract (e.g., "PER", "ORG")
+    # @param confidence_threshold [Float] Minimum confidence score
+    # @return [Array<Hash>] Filtered entities of the specified type
+    def extract_entity_type(text, entity_type, confidence_threshold: 0.9)
+      entities = extract_entities(text, confidence_threshold: confidence_threshold)
+      entities.select { |e| e["label"] == entity_type.upcase }
+    end
+
+    # Analyze text and return both entities and token predictions
+    #
+    # @param text [String] The text to analyze
+    # @param confidence_threshold [Float] Minimum confidence for entities
+    # @return [Hash] Hash with :entities and :tokens keys
+    def analyze(text, confidence_threshold: 0.9)
+      {
+        entities: extract_entities(text, confidence_threshold: confidence_threshold),
+        tokens: predict_tokens(text)
+      }
+    end
+
+    # Get a formatted string representation of entities
+    #
+    # @param text [String] The text to analyze
+    # @param confidence_threshold [Float] Minimum confidence score
+    # @return [String] Formatted output with entities highlighted
+    def format_entities(text, confidence_threshold: 0.9)
+      entities = extract_entities(text, confidence_threshold: confidence_threshold)
+      return text if entities.empty?
+
+      # Sort by start position (reverse for easier insertion)
+      entities.sort_by! { |e| -e["start"] }
+
+      result = text.dup
+      entities.each do |entity|
+        label = "[#{entity['label']}:#{entity['confidence'].round(2)}]"
+        result.insert(entity["end"], label)
+      end
+
+      result
+    end
+
+    # Get model information
+    #
+    # @return [String] Model description
+    def inspect
+      "#<Candle::NER #{model_info}>"
+    end
+
+    alias to_s inspect
+  end
+
+  # Pattern-based entity recognizer for custom entities
+  class PatternEntityRecognizer
+    attr_reader :patterns, :entity_type
+
+    def initialize(entity_type, patterns = [])
+      @entity_type = entity_type
+      @patterns = patterns
+    end
+
+    # Add a pattern (String or Regexp)
+    def add_pattern(pattern)
+      @patterns << pattern
+      self
+    end
+
+    # Recognize entities using patterns
+    def recognize(text, tokenizer = nil)
+      entities = []
+
+      @patterns.each do |pattern|
+        regex = pattern.is_a?(Regexp) ? pattern : Regexp.new(pattern)
+
+        text.scan(regex) do |match|
+          match_text = $&
+          match_start = $~.offset(0)[0]
+          match_end = $~.offset(0)[1]
+
+          entities << {
+            "text" => match_text,
+            "label" => @entity_type,
+            "start" => match_start,
+            "end" => match_end,
+            "confidence" => 1.0,
+            "source" => "pattern"
+          }
+        end
+      end
+
+      entities
+    end
+  end
+
+  # Gazetteer-based entity recognizer
+  class GazetteerEntityRecognizer
+    attr_reader :entity_type, :terms, :case_sensitive
+
+    def initialize(entity_type, terms = [], case_sensitive: false)
+      @entity_type = entity_type
+      @case_sensitive = case_sensitive
+      @terms = build_term_set(terms)
+    end
+
+    # Add terms to the gazetteer
+    def add_terms(terms)
+      terms = [terms] unless terms.is_a?(Array)
+      terms.each { |term| @terms.add(normalize_term(term)) }
+      self
+    end
+
+    # Load terms from file
+    def load_from_file(filepath)
+      File.readlines(filepath).each do |line|
+        term = line.strip
+        add_terms(term) unless term.empty? || term.start_with?("#")
+      end
+      self
+    end
+
+    # Recognize entities using the gazetteer
+    def recognize(text, tokenizer = nil)
+      entities = []
+      normalized_text = @case_sensitive ? text : text.downcase
+
+      @terms.each do |term|
+        pattern = @case_sensitive ? term : term.downcase
+        pos = 0
+
+        while (idx = normalized_text.index(pattern, pos))
+          # Check word boundaries
+          prev_char = idx > 0 ? text[idx - 1] : " "
+          next_char = idx + pattern.length < text.length ? text[idx + pattern.length] : " "
+
+          if word_boundary?(prev_char) && word_boundary?(next_char)
+            entities << {
+              "text" => text[idx, pattern.length],
+              "label" => @entity_type,
+              "start" => idx,
+              "end" => idx + pattern.length,
+              "confidence" => 1.0,
+              "source" => "gazetteer"
+            }
+          end
+
+          pos = idx + 1
+        end
+      end
+
+      entities
+    end
+
+    private
+
+    def build_term_set(terms)
+      Set.new(terms.map { |term| normalize_term(term) })
+    end
+
+    def normalize_term(term)
+      @case_sensitive ? term : term.downcase
+    end
+
+    def word_boundary?(char)
+      char.match?(/\W/)
+    end
+  end
+
+  # Hybrid NER that combines ML model with rules
+  class HybridNER
+    attr_reader :model_ner, :pattern_recognizers, :gazetteer_recognizers
+
+    def initialize(model_id = nil, device: nil)
+      @model_ner = model_id ? NER.from_pretrained(model_id, device: device) : nil
+      @pattern_recognizers = []
+      @gazetteer_recognizers = []
+    end
+
+    # Add a pattern-based recognizer
+    def add_pattern_recognizer(entity_type, patterns)
+      recognizer = PatternEntityRecognizer.new(entity_type, patterns)
+      @pattern_recognizers << recognizer
+      self
+    end
+
+    # Add a gazetteer-based recognizer
+    def add_gazetteer_recognizer(entity_type, terms, **options)
+      recognizer = GazetteerEntityRecognizer.new(entity_type, terms, **options)
+      @gazetteer_recognizers << recognizer
+      self
+    end
+
+    # Extract entities using all recognizers
+    def extract_entities(text, confidence_threshold: 0.9)
+      all_entities = []
+
+      # Model-based entities
+      if @model_ner
+        model_entities = @model_ner.extract_entities(text, confidence_threshold: confidence_threshold)
+        all_entities.concat(model_entities)
+      end
+
+      # Pattern-based entities
+      @pattern_recognizers.each do |recognizer|
+        pattern_entities = recognizer.recognize(text)
+        all_entities.concat(pattern_entities)
+      end
+
+      # Gazetteer-based entities
+      @gazetteer_recognizers.each do |recognizer|
+        gazetteer_entities = recognizer.recognize(text)
+        all_entities.concat(gazetteer_entities)
+      end
+
+      # Merge overlapping entities (prefer highest confidence)
+      merge_entities(all_entities)
+    end
+
+    private
+
+    def merge_entities(entities)
+      # Sort by start position and confidence (descending)
+      sorted = entities.sort_by { |e| [e["start"], -e["confidence"]] }
+
+      merged = []
+      sorted.each do |entity|
+        # Check if entity overlaps with any already merged
+        overlaps = merged.any? do |existing|
+          entity["start"] < existing["end"] && entity["end"] > existing["start"]
+        end
+
+        merged << entity unless overlaps
+      end
+
+      merged.sort_by { |e| e["start"] }
+    end
+  end
+end
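The rule-based recognizers are designed to compose with the model-backed NER via HybridNER. A short sketch using only the classes added above; the model ID and patterns are illustrative:

    require 'candle'

    hybrid = Candle::HybridNER.new("Babelscape/wikineural-multilingual-ner")

    # Regex rule: matches are reported with confidence 1.0 and source "pattern".
    hybrid.add_pattern_recognizer("EMAIL", [/\b[\w.+-]+@[\w-]+\.\w+\b/])

    # Gazetteer: exact terms matched on word boundaries, case-insensitive by default.
    hybrid.add_gazetteer_recognizer("PRODUCT", ["red-candle", "candle"])

    # Overlapping spans are merged, keeping the highest-confidence one.
    entities = hybrid.extract_entities("Contact sales@example.com about red-candle.")

Because pattern and gazetteer hits carry confidence 1.0, they win ties against overlapping model spans during merging.
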
data/lib/candle/reranker.rb
CHANGED
@@ -10,7 +10,7 @@ module Candle
       _create(model_path, device)
     end

-    # Returns
+    # Returns documents ranked by relevance using the specified pooling method.
     # @param query [String] The input text
     # @param documents [Array<String>] The list of documents to compare against
     # @param pooling_method [String] Pooling method: "pooler", "cls", or "mean". Default: "pooler"
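The completed docstring describes the reranking call's parameters. A hedged usage sketch; the constructor shape and the method name `rerank` are assumptions, as neither appears in this hunk:

    require 'candle'

    reranker = Candle::Reranker.new("cross-encoder/ms-marco-MiniLM-L-12-v2")

    # pooling_method may be "pooler" (default), "cls", or "mean" per the docs above.
    reranker.rerank("what is candle?",
                    ["Candle is a Rust ML framework.", "Candles are made of wax."],
                    pooling_method: "mean")
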
data/lib/candle/tensor.rb
CHANGED
@@ -14,6 +14,7 @@ module Candle
         begin
           values_f32.each { |value| yield value }
         rescue NoMethodError
+          # :nocov:
           # If values_f32 isn't available yet (not recompiled), fall back
           if device.to_s != "cpu"
             # Move to CPU to avoid Metal F32->F64 conversion issue
@@ -21,6 +22,7 @@ module Candle
           else
             values.each { |value| yield value }
           end
+          # :nocov:
         end
       else
         # For non-F32 dtypes, use regular values
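The guarded block concerns iterating F32 tensors: `values_f32` sidesteps a Metal F32-to-F64 conversion issue, and the `# :nocov:` pair excludes the fallback taken only by stale native builds. A minimal sketch; the array constructor is an assumption:

    t = Candle::Tensor.new([1.0, 2.0, 3.0])   # assumed to default to F32
    t.each { |v| puts v }   # prefers values_f32, falls back to values otherwise
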
data/lib/candle/tokenizer.rb
ADDED
@@ -0,0 +1,139 @@
+# frozen_string_literal: true
+
+module Candle
+  # Tokenizer class for text tokenization
+  #
+  # This class provides methods to encode text into tokens and decode tokens back to text.
+  # It supports both single text and batch processing, with options for special tokens,
+  # padding, and truncation.
+  #
+  # @example Create a tokenizer from a pretrained model
+  #   tokenizer = Candle::Tokenizer.from_pretrained("bert-base-uncased")
+  #
+  # @example Encode and decode text
+  #   tokens = tokenizer.encode("Hello, world!")
+  #   text = tokenizer.decode(tokens)
+  #
+  # @example Batch encoding
+  #   texts = ["Hello", "World", "Test"]
+  #   batch_tokens = tokenizer.encode_batch(texts)
+  #
+  # @example Configure padding and truncation
+  #   padded_tokenizer = tokenizer.with_padding(length: 128)
+  #   truncated_tokenizer = tokenizer.with_truncation(512)
+  class Tokenizer
+    # These methods are implemented in Rust
+    # - from_file(path) - Load tokenizer from a JSON file
+    # - from_pretrained(model_id) - Load tokenizer from HuggingFace
+    # - encode(text, add_special_tokens = true) - Encode text to token IDs
+    # - encode_to_tokens(text, add_special_tokens = true) - Encode text to token strings
+    # - encode_with_tokens(text, add_special_tokens = true) - Get both IDs and tokens
+    # - encode_batch(texts, add_special_tokens = true) - Encode multiple texts to IDs
+    # - encode_batch_to_tokens(texts, add_special_tokens = true) - Encode multiple texts to tokens
+    # - decode(token_ids, skip_special_tokens = true) - Decode token IDs to text
+    # - id_to_token(token_id) - Get token string for a token ID
+    # - get_vocab(with_added_tokens = true) - Get vocabulary as hash
+    # - vocab_size(with_added_tokens = true) - Get vocabulary size
+    # - with_padding(options) - Create tokenizer with padding enabled
+    # - with_truncation(max_length) - Create tokenizer with truncation enabled
+    # - get_special_tokens - Get special tokens info
+    # - inspect - String representation
+    # - to_s - String representation
+
+    # The native methods accept positional arguments, but we provide keyword argument interfaces
+    # for better Ruby ergonomics. We need to call the native methods with positional args.
+
+    alias_method :_native_encode, :encode
+    alias_method :_native_encode_to_tokens, :encode_to_tokens
+    alias_method :_native_encode_with_tokens, :encode_with_tokens
+    alias_method :_native_encode_batch, :encode_batch
+    alias_method :_native_encode_batch_to_tokens, :encode_batch_to_tokens
+    alias_method :_native_decode, :decode
+    alias_method :_native_get_vocab, :get_vocab
+    alias_method :_native_vocab_size, :vocab_size
+    alias_method :_native_with_padding, :with_padding
+
+    # Encode text with convenient keyword arguments
+    #
+    # @param text [String] The text to encode
+    # @param add_special_tokens [Boolean] Whether to add special tokens (default: true)
+    # @return [Array<Integer>] Token IDs
+    def encode(text, add_special_tokens: true)
+      _native_encode(text, add_special_tokens)
+    end
+
+    # Encode text into token strings (words/subwords)
+    #
+    # @param text [String] The text to encode
+    # @param add_special_tokens [Boolean] Whether to add special tokens (default: true)
+    # @return [Array<String>] Token strings
+    def encode_to_tokens(text, add_special_tokens: true)
+      _native_encode_to_tokens(text, add_special_tokens)
+    end
+
+    # Encode text and return both IDs and token strings
+    #
+    # @param text [String] The text to encode
+    # @param add_special_tokens [Boolean] Whether to add special tokens (default: true)
+    # @return [Hash] Hash with :ids and :tokens arrays
+    def encode_with_tokens(text, add_special_tokens: true)
+      _native_encode_with_tokens(text, add_special_tokens)
+    end
+
+    # Encode multiple texts with convenient keyword arguments
+    #
+    # @param texts [Array<String>] The texts to encode
+    # @param add_special_tokens [Boolean] Whether to add special tokens (default: true)
+    # @return [Array<Array<Integer>>] Arrays of token IDs
+    def encode_batch(texts, add_special_tokens: true)
+      _native_encode_batch(texts, add_special_tokens)
+    end
+
+    # Encode multiple texts into token strings
+    #
+    # @param texts [Array<String>] The texts to encode
+    # @param add_special_tokens [Boolean] Whether to add special tokens (default: true)
+    # @return [Array<Array<String>>] Arrays of token strings
+    def encode_batch_to_tokens(texts, add_special_tokens: true)
+      _native_encode_batch_to_tokens(texts, add_special_tokens)
+    end
+
+    # Decode token IDs with convenient keyword arguments
+    #
+    # @param token_ids [Array<Integer>] The token IDs to decode
+    # @param skip_special_tokens [Boolean] Whether to skip special tokens (default: true)
+    # @return [String] Decoded text
+    def decode(token_ids, skip_special_tokens: true)
+      _native_decode(token_ids, skip_special_tokens)
+    end
+
+    # Get vocabulary with convenient keyword arguments
+    #
+    # @param with_added_tokens [Boolean] Include added tokens (default: true)
+    # @return [Hash<String, Integer>] Token to ID mapping
+    def get_vocab(with_added_tokens: true)
+      _native_get_vocab(with_added_tokens)
+    end
+
+    # Get vocabulary size with convenient keyword arguments
+    #
+    # @param with_added_tokens [Boolean] Include added tokens (default: true)
+    # @return [Integer] Vocabulary size
+    def vocab_size(with_added_tokens: true)
+      _native_vocab_size(with_added_tokens)
+    end
+
+    # Create a new tokenizer with padding configuration
+    #
+    # @param options [Hash] Padding options
+    # @option options [Integer] :length Fixed length padding
+    # @option options [Boolean] :max_length Use batch longest padding
+    # @option options [String] :direction Padding direction ("left" or "right")
+    # @option options [Integer] :pad_id Padding token ID
+    # @option options [String] :pad_token Padding token string
+    # @return [Tokenizer] New tokenizer instance with padding enabled
+    def with_padding(**options)
+      _native_with_padding(options)
+    end
+  end
+end
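The keyword wrappers above delegate to the positional native methods, which keeps batch workflows concise. A short sketch using only methods documented in the file (the model ID is illustrative; an uncased model lowercases text, so the round-trip is not byte-identical):

    require 'candle'

    tokenizer = Candle::Tokenizer.from_pretrained("bert-base-uncased")

    ids = tokenizer.encode("Hello, world!", add_special_tokens: true)
    tokenizer.decode(ids, skip_special_tokens: true)   # => "hello, world!"

    # Pad every sequence in a batch to a fixed length of 16 tokens.
    padded = tokenizer.with_padding(length: 16)
    padded.encode_batch(["Hello", "World"])
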
data/lib/candle/version.rb
CHANGED
data/lib/candle.rb
CHANGED