informers 0.2.0 → 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +63 -99
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +28 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
@@ -0,0 +1,439 @@
|
|
1
|
+
module Informers
|
2
|
+
# Base class for every task pipeline. It only keeps the components handed
# over by Informers.pipeline; subclasses implement #call.
class Pipeline
  # @param task [String] task name this pipeline was built for
  # @param model [Object] loaded model (subclasses invoke it via #call)
  # @param tokenizer [Object, nil] loaded tokenizer, if any
  # @param processor [Object, nil] loaded processor, if any
  def initialize(task:, model:, tokenizer: nil, processor: nil)
    super()
    @task, @model, @tokenizer, @processor = task, model, tokenizer, processor
  end
end
|
11
|
+
|
12
|
+
# Sequence-classification pipeline (registered for "text-classification"
# and its "sentiment-analysis" alias).
class TextClassificationPipeline < Pipeline
  def initialize(**options)
    super(**options)
  end

  # Classifies one text or a batch of texts.
  #
  # Logits go through a sigmoid when the model config's problem_type is
  # "multi_label_classification" and through a softmax otherwise.
  #
  # @param texts [String, Array<String>] a single text or a batch
  # @param top_k [Integer] how many of the best labels to keep per text
  # @return for a String input: the first result entry; for an Array input:
  #   all entries. An entry is a {label:, score:} Hash when top_k == 1 and an
  #   Array of such Hashes otherwise.
  def call(texts, top_k: 1)
    encoded = @tokenizer.call(texts, padding: true, truncation: true)
    logits = @model.call(encoded).logits

    config = @model.config
    multi_label = config.problem_type == "multi_label_classification"
    labels = config.id2label

    per_text = logits.map do |row|
      probabilities = multi_label ? Utils.sigmoid(row) : Utils.softmax(row)
      Utils.get_top_items(probabilities, top_k).map do |item|
        # item is [label_index, score]; id2label is keyed by string indices
        {label: labels[item[0].to_s], score: item[1]}
      end
    end

    # With a single label per text, return bare hashes instead of 1-element lists.
    per_text = per_text.flatten(1) if top_k == 1
    texts.is_a?(Array) ? per_text : per_text.first
  end
end
|
57
|
+
|
58
|
+
# Token-classification pipeline (registered for "token-classification" and
# its "ner" alias). It labels every token, drops ignored labels and special
# tokens, and can merge neighbouring tokens into entity groups.
class TokenClassificationPipeline < Pipeline
  def initialize(**options)
    super(**options)
  end

  # @param texts [String, Array<String>] a single text or a batch
  # @param ignore_labels [Array<String>] predicted labels to drop
  # @param aggregation_strategy [String] "simple" merges tokens into entity
  #   groups (see #group_entities); "none" returns one Hash per token
  # @return [Array<Hash>] for a String input, or an Array of those per text
  # @raise [ArgumentError] if aggregation_strategy is not "simple" or "none"
  def call(
    texts,
    ignore_labels: ["O"],
    aggregation_strategy: "simple"
  )
    # Validate before doing any tokenization or model work.
    unless ["simple", "none"].include?(aggregation_strategy)
      raise ArgumentError, "Invalid aggregation_strategy"
    end

    is_batched = texts.is_a?(Array)

    # Run tokenization
    model_inputs = @tokenizer.(is_batched ? texts : [texts],
      padding: true,
      truncation: true,
      return_offsets: true
    )

    # Run model
    outputs = @model.(model_inputs)

    logits = outputs.logits
    id2label = @model.config.id2label

    to_return = []
    logits.length.times do |i|
      ids = model_inputs[:input_ids][i]
      batch = logits[i]
      offsets = model_inputs[:offsets][i]

      # List of tokens that aren't ignored
      tokens = []
      batch.length.times do |j|
        token_data = batch[j]
        top_score_index = Utils.max(token_data)[1]

        entity = id2label ? id2label[top_score_index.to_s] : "LABEL_#{top_score_index}"
        # We predicted a token that should be ignored, so we skip it.
        next if ignore_labels.include?(entity)

        # TODO add option to keep special tokens?
        word = @tokenizer.decode([ids[j]], skip_special_tokens: true)
        # An empty decode means it was a special token, so we skip it.
        next if word == ""

        scores = Utils.softmax(token_data)

        tokens << {
          entity: entity,
          score: scores[top_score_index],
          index: j,
          word: word,
          start: offsets[j][0],
          end: offsets[j][1]
        }
      end

      tokens = group_entities(tokens) if aggregation_strategy == "simple"

      to_return << tokens
    end
    is_batched ? to_return : to_return[0]
  end

  # Collapses a run of token hashes into one entity-group Hash. The group
  # takes its tag from the first token (a "B-"/"I-" prefix is removed), the
  # mean token score, the words joined by the tokenizer, and the span from
  # the first token's start to the last token's end.
  def group_sub_entities(entities)
    entity = entities[0][:entity].split("-", 2)[-1]
    scores = entities.map { |e| e[:score] }
    tokens = entities.map { |e| e[:word] }

    {
      entity_group: entity,
      score: scores.sum / scores.count.to_f,
      word: @tokenizer.convert_tokens_to_string(tokens),
      start: entities[0][:start],
      end: entities[-1][:end]
    }
  end

  # Splits an entity label into [bi, tag]. "B-X" gives ["B", "X"] and "I-X"
  # gives ["I", "X"]; a label without either prefix is treated as an "I"
  # (continuation) tag.
  def get_tag(entity_name)
    if entity_name.start_with?("B-")
      bi = "B"
      tag = entity_name[2..]
    elsif entity_name.start_with?("I-")
      bi = "I"
      tag = entity_name[2..]
    else
      # It's not in B-, I- format
      # Default to I- for continuation.
      bi = "I"
      tag = entity_name
    end
    [bi, tag]
  end

  # Merges consecutive token hashes into entity groups. A token joins the
  # current group only when it has the same tag, is not a "B-" token, and
  # directly follows the group's last token.
  def group_entities(entities)
    entity_groups = []
    entity_group_disagg = []

    entities.each do |entity|
      if entity_group_disagg.empty?
        entity_group_disagg << entity
        next
      end

      bi, tag = get_tag(entity[:entity])
      last = entity_group_disagg[-1]
      _last_bi, last_tag = get_tag(last[:entity])

      # Tokens with ignored labels were already removed, leaving a gap in
      # :index. A gap must end the group, otherwise two separate entities
      # with the same tag would be fused into one span. Hashes without an
      # :index (external callers) keep the previous, tag-only behaviour.
      adjacent = entity[:index].nil? || last[:index].nil? || entity[:index] == last[:index] + 1

      if tag == last_tag && bi != "B" && adjacent
        entity_group_disagg << entity
      else
        # The current entity starts a new group; close the previous one.
        entity_groups << group_sub_entities(entity_group_disagg)
        entity_group_disagg = [entity]
      end
    end

    # Close the final group.
    entity_groups << group_sub_entities(entity_group_disagg) if entity_group_disagg.any?

    entity_groups
  end
end
|
201
|
+
|
202
|
+
# Extractive question-answering pipeline: picks answer spans from the context.
class QuestionAnsweringPipeline < Pipeline
  def initialize(**options)
    super(**options)
  end

  # @param question [String, Array<String>] a question or a batch of questions
  # @param context [String, Array<String>] the matching context(s); it must be
  #   an Array of the same length when question is an Array
  # @param top_k [Integer] number of answers to return per question
  # @return for a single question: one answer Hash when top_k == 1, otherwise
  #   an Array of up to top_k answer Hashes (best first). For a batch: an
  #   Array with one such entry per question. An answer is
  #   {answer:, score:, start:, end:}, where start/end are offsets taken from
  #   the tokenizer. The entry is nil (top_k == 1) or [] when no span qualifies.
  def call(question, context, top_k: 1)
    # Run tokenization
    inputs = @tokenizer.(question,
      text_pair: context,
      padding: true,
      truncation: true,
      return_offsets: true
    )

    output = @model.(inputs)

    to_return = []
    output.start_logits.length.times do |j|
      ids = inputs[:input_ids][j]
      # Only positions after the first separator token are answer candidates.
      sep_index = ids.index(@tokenizer.sep_token_id)
      offsets = inputs[:offsets][j]

      # [probability, token_index] pairs for candidate start and end positions
      s1 = Utils.softmax(output.start_logits[j])
        .each_with_index
        .select { |_prob, index| index > sep_index }
      e1 = Utils.softmax(output.end_logits[j])
        .each_with_index
        .select { |_prob, index| index > sep_index }

      # Candidate spans [start, end, score], best first; a span cannot end before it starts.
      options = s1.product(e1)
        .select { |x| x[0][1] <= x[1][1] }
        .map { |x| [x[0][1], x[1][1], x[0][0] * x[1][0]] }
        .sort_by { |v| -v[2] }

      answers = options.first(top_k).map do |start, end_, score|
        # Inclusive token range. Array#slice(start, n) takes a *length*, so
        # slice(start, end_ + 1) would run past the end of the answer.
        answer_tokens = ids[start..end_]

        answer = @tokenizer.decode(answer_tokens,
          skip_special_tokens: true
        )

        {
          answer:,
          score:,
          start: offsets[start][0],
          end: offsets[end_][1]
        }
      end

      # Keep results grouped per question so batches stay aligned, and so a
      # single question with top_k > 1 returns all of its answers.
      to_return << (top_k == 1 ? answers[0] : answers)
    end

    question.is_a?(Array) ? to_return : to_return[0]
  end
end
|
257
|
+
|
258
|
+
# Feature-extraction (embedding) pipeline.
class FeatureExtractionPipeline < Pipeline
  def initialize(**options)
    super(**options)
  end

  # @param texts [String, Array<String>] a single text or a batch
  # @param pooling [String] "none" (per-token output), "mean" (attention-mask
  #   weighted mean via Utils.mean_pooling) or "cls" (first token)
  # @param normalize [Boolean] apply Utils.normalize to the result
  # @param quantize [Boolean] bit-pack the result (see #quantize_embeddings);
  #   requires pooling so that each embedding is one flat Array
  # @param precision [String] "binary" (signed bytes) or "ubinary" (unsigned bytes)
  # @return the batch result for an Array input, or its first row for a String
  # @raise [Error] for an unsupported pooling method or invalid quantization input
  def call(
    texts,
    pooling: "none",
    normalize: false,
    quantize: false,
    precision: "binary"
  )
    # Run tokenization
    model_inputs = @tokenizer.(texts,
      padding: true,
      truncation: true
    )

    # Run model
    outputs = @model.(model_inputs)

    # TODO check outputs.last_hidden_state
    result = outputs.logits
    case pooling
    when "none"
      # Skip pooling
    when "mean"
      result = Utils.mean_pooling(result, model_inputs[:attention_mask])
    when "cls"
      result = result.map(&:first)
    else
      raise Error, "Pooling method '#{pooling}' not supported."
    end

    if normalize
      result = Utils.normalize(result)
    end

    if quantize
      result = quantize_embeddings(result, precision)
    end

    texts.is_a?(Array) ? result : result[0]
  end

  private

  # Binary quantization of 2-D embeddings (one flat Array per text). Each
  # value becomes one bit (1 when > 0), and every 8 bits are packed into one
  # integer, most significant bit first. "ubinary" yields values in 0..255;
  # "binary" subtracts 128, giving signed values in -128..127.
  #
  # @raise [Error] for an unknown precision, non-2-D input, or a dimension
  #   that is not a multiple of 8
  def quantize_embeddings(embeddings, precision)
    unless ["binary", "ubinary"].include?(precision)
      raise Error, "The precision must be either 'binary' or 'ubinary'"
    end
    unless embeddings.all? { |row| row.is_a?(Array) && row.none?(Array) }
      raise Error, "Quantization requires 2-dimensional embeddings (use pooling)"
    end

    signed = precision == "binary"
    embeddings.map do |row|
      unless (row.length % 8).zero?
        raise Error, "The embedding dimension must be a multiple of 8"
      end

      row.each_slice(8).map do |bits|
        byte = bits.map { |value| value > 0 ? "1" : "0" }.join.to_i(2)
        signed ? byte - 128 : byte
      end
    end
  end
end
|
303
|
+
|
304
|
+
# Registry of tasks understood by Informers.pipeline. For each task it lists
# the tokenizer, pipeline and model classes to load, the model used when the
# caller passes none, and the input type.
SUPPORTED_TASKS = {
  "text-classification" => {
    tokenizer: AutoTokenizer,
    pipeline: TextClassificationPipeline,
    model: AutoModelForSequenceClassification,
    default: {model: "Xenova/distilbert-base-uncased-finetuned-sst-2-english"},
    type: "text"
  },
  "token-classification" => {
    tokenizer: AutoTokenizer,
    pipeline: TokenClassificationPipeline,
    model: AutoModelForTokenClassification,
    default: {model: "Xenova/bert-base-multilingual-cased-ner-hrl"},
    type: "text"
  },
  "question-answering" => {
    tokenizer: AutoTokenizer,
    pipeline: QuestionAnsweringPipeline,
    model: AutoModelForQuestionAnswering,
    default: {model: "Xenova/distilbert-base-cased-distilled-squad"},
    type: "text"
  },
  "feature-extraction" => {
    tokenizer: AutoTokenizer,
    pipeline: FeatureExtractionPipeline,
    model: AutoModel,
    default: {model: "Xenova/all-MiniLM-L6-v2"},
    type: "text"
  }
}
|
342
|
+
|
343
|
+
# Alternative task names accepted by Informers.pipeline, each mapped to a
# key of SUPPORTED_TASKS.
TASK_ALIASES = {"sentiment-analysis" => "text-classification", "ner" => "token-classification"}
|
347
|
+
|
348
|
+
# Default download-progress reporter, writing to $stderr.
# On a terminal, "progress" messages redraw one line in place, sized to the
# terminal width. A "done" message that is not a cache hit ends that line
# (terminal) or prints a single completed line (non-terminal). All other
# messages are ignored. Always returns nil.
DEFAULT_PROGRESS_CALLBACK = lambda do |msg|
  out = $stderr
  interactive = out.tty?
  columns = interactive ? out.winsize[1] : 80

  case msg[:status]
  when "progress"
    out.print "\r#{Utils::Hub.display_progress(msg[:file], columns, msg[:size], msg[:total_size])}" if interactive
  when "done"
    next if msg[:cache_hit]

    if interactive
      out.puts
    else
      out.puts Utils::Hub.display_progress(msg[:file], columns, 1, 1)
    end
  end
end
|
363
|
+
|
364
|
+
class << self
  # Builds a ready-to-use pipeline for +task+. The task's tokenizer, model
  # and (when the task lists one) processor are each loaded via
  # from_pretrained.
  #
  # @param task [String] a key of SUPPORTED_TASKS or an alias from
  #   TASK_ALIASES. Only the part before the first "_" selects the pipeline;
  #   the full name is passed to the pipeline as its task.
  # @param model [String, nil] model to load; when nil, the task's default
  #   model is used and a warning is printed
  # @param quantized, progress_callback, config, cache_dir, local_files_only,
  #   revision, model_file_name are forwarded unchanged to every from_pretrained call
  # @return [Pipeline] an instance of the task's pipeline class
  # @raise [Error] if the task is not supported
  def pipeline(
    task,
    model = nil,
    quantized: true,
    progress_callback: DEFAULT_PROGRESS_CALLBACK,
    config: nil,
    cache_dir: nil,
    local_files_only: false,
    revision: "main",
    model_file_name: nil
  )
    # Apply aliases
    task = TASK_ALIASES[task] || task

    # Get pipeline info. The pipeline is named by the prefix before the
    # first "_". Ruby's String#split limit counts resulting fields, so a
    # limit of 2 is needed; a limit of 1 never splits at all.
    pipeline_info = SUPPORTED_TASKS[task.split("_", 2)[0]]
    unless pipeline_info
      raise Error, "Unsupported pipeline: #{task}. Must be one of #{SUPPORTED_TASKS.keys}"
    end

    # Use model if specified, otherwise, use default
    unless model
      model = pipeline_info[:default][:model]
      warn "No model specified. Using default model: #{model.inspect}."
    end

    pretrained_options = {
      quantized:,
      progress_callback:,
      config:,
      cache_dir:,
      local_files_only:,
      revision:,
      model_file_name:
    }

    classes = {
      tokenizer: pipeline_info[:tokenizer],
      model: pipeline_info[:model],
      processor: pipeline_info[:processor]
    }

    # Load model, tokenizer, and processor (if they exist)
    results = load_items(classes, model, pretrained_options)
    results[:task] = task

    Utils.dispatch_callback(progress_callback, {
      status: "ready",
      task: task,
      model: model
    })

    pipeline_class = pipeline_info.fetch(:pipeline)
    pipeline_class.new(**results)
  end

  private

  # Loads each component class in +mapping+ with from_pretrained, skipping
  # nil entries. Returns a Hash keyed by component name (:tokenizer, :model, ...).
  def load_items(mapping, model, pretrained_options)
    result = {}

    mapping.each do |name, cls|
      next if !cls

      if cls.is_a?(Array)
        # Multiple candidate classes are not supported yet.
        raise Todo
      else
        result[name] = cls.from_pretrained(model, **pretrained_options)
      end
    end

    result
  end
end
|
439
|
+
end
|
@@ -0,0 +1,141 @@
|
|
1
|
+
module Informers
|
2
|
+
# Wrapper around a Tokenizers::Tokenizer built from a tokenizer.json file,
# configured with values read from the parsed tokenizer_config.json Hash.
class PreTrainedTokenizer
  attr_reader :sep_token_id

  # @param tokenizer_json [String] path passed to Tokenizers::Tokenizer.from_file
  # @param tokenizer_config [Hash] parsed config; "sep_token" and
  #   "model_max_length" are read from it
  def initialize(tokenizer_json, tokenizer_config)
    super()

    @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

    @sep_token = tokenizer_config["sep_token"]
    @sep_token_id = @tokenizer.token_to_id(@sep_token)

    @model_max_length = tokenizer_config["model_max_length"]
  end

  # Encodes a text (or a batch of texts, optionally paired) into model inputs.
  #
  # Padding and truncation are reconfigured on the wrapped tokenizer on every
  # call. Truncation uses max_length, falling back to the config's
  # model_max_length.
  #
  # @return [Hash] :input_ids and :attention_mask (one entry per text), plus
  #   :token_type_ids and/or :offsets when requested
  # @raise [Error] for an empty batch or a text_pair that does not match the batch
  def call(
    text,
    text_pair: nil,
    add_special_tokens: true,
    padding: false,
    truncation: nil,
    max_length: nil,
    return_tensor: true,
    return_token_type_ids: true, # TODO change default
    return_offsets: false
  )
    batched = text.is_a?(Array)
    check_batch_inputs(text, text_pair) if batched

    padding ? @tokenizer.enable_padding : @tokenizer.no_padding
    truncation ? @tokenizer.enable_truncation(max_length || @model_max_length) : @tokenizer.no_truncation

    encodings =
      if batched
        pairs_or_texts = text_pair ? text.zip(text_pair) : text
        @tokenizer.encode_batch(pairs_or_texts, add_special_tokens: add_special_tokens)
      else
        [@tokenizer.encode(text, text_pair, add_special_tokens: add_special_tokens)]
      end

    output = {
      input_ids: encodings.map(&:ids),
      attention_mask: encodings.map(&:attention_mask)
    }
    output[:token_type_ids] = encodings.map(&:type_ids) if return_token_type_ids
    output[:offsets] = encodings.map(&:offsets) if return_offsets
    output
  end

  # Decodes token ids back to text via the wrapped tokenizer.
  def decode(tokens, skip_special_tokens:)
    @tokenizer.decode(tokens, skip_special_tokens: skip_special_tokens)
  end

  # Joins token strings with the wrapped tokenizer's decoder.
  def convert_tokens_to_string(tokens)
    @tokenizer.decoder.decode(tokens)
  end

  private

  # Validates batched input: the batch must be non-empty, and a non-nil
  # text_pair must be an Array of the same length.
  def check_batch_inputs(text, text_pair)
    raise Error, "text array must be non-empty" if text.empty?
    return if text_pair.nil?
    raise Error, "text_pair must also be an array" unless text_pair.is_a?(Array)
    raise Error, "text and text_pair must have the same length" if text.length != text_pair.length
  end
end
|
80
|
+
|
81
|
+
# Tokenizer selected for configs whose tokenizer_class is "BertTokenizer"
# (or "BertTokenizerFast"). It currently inherits all behavior from
# PreTrainedTokenizer.
class BertTokenizer < PreTrainedTokenizer
  # TODO: set `self.return_token_type_ids = true` once PreTrainedTokenizer
  # supports a per-class default for it.
end
|
85
|
+
|
86
|
+
# Tokenizer selected for configs whose tokenizer_class is
# "DistilBertTokenizer" (or "DistilBertTokenizerFast"). It adds nothing to
# PreTrainedTokenizer.
class DistilBertTokenizer < PreTrainedTokenizer
end
|
88
|
+
|
89
|
+
# Chooses and instantiates the tokenizer class named by a model's
# tokenizer_config.json.
class AutoTokenizer
  TOKENIZER_CLASS_MAPPING = {
    "BertTokenizer" => BertTokenizer,
    "DistilBertTokenizer" => DistilBertTokenizer,
    # from_pretrained falls back to this name when the config has no
    # "tokenizer_class". Listing it avoids a spurious "Unknown tokenizer
    # class" warning for such configs.
    "PreTrainedTokenizer" => PreTrainedTokenizer
  }

  # Downloads (or reads from cache) tokenizer.json and tokenizer_config.json
  # for +pretrained_model_name_or_path+ and builds the matching tokenizer.
  #
  # @param legacy [Boolean, nil] when non-nil, overrides the config's "legacy" value
  # @param kwargs [Hash] accepted and ignored, so callers may pass the full
  #   set of pipeline options (e.g. model_file_name)
  # @return [PreTrainedTokenizer] an instance of the mapped class; unknown
  #   class names fall back to PreTrainedTokenizer with a warning
  def self.from_pretrained(
    pretrained_model_name_or_path,
    quantized: true,
    progress_callback: nil,
    config: nil,
    cache_dir: nil,
    local_files_only: false,
    revision: "main",
    legacy: nil,
    **kwargs
  )
    tokenizer_json, tokenizer_config = load_tokenizer(
      pretrained_model_name_or_path,
      quantized:,
      progress_callback:,
      config:,
      cache_dir:,
      local_files_only:,
      revision:,
      legacy:
    )

    # Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
    tokenizer_name = tokenizer_config["tokenizer_class"]&.delete_suffix("Fast") || "PreTrainedTokenizer"

    cls = TOKENIZER_CLASS_MAPPING[tokenizer_name]
    unless cls
      warn "Unknown tokenizer class #{tokenizer_name.inspect}, attempting to construct from base class."
      cls = PreTrainedTokenizer
    end
    cls.new(tokenizer_json, tokenizer_config)
  end

  # Fetches the tokenizer files via Utils::Hub, passing +options+ through.
  #
  # @return [Array] [tokenizer.json file result, parsed tokenizer_config.json]
  def self.load_tokenizer(pretrained_model_name_or_path, **options)
    tokenizer_json = Utils::Hub.get_model_file(pretrained_model_name_or_path, "tokenizer.json", true, **options)
    tokenizer_config = Utils::Hub.get_model_json(pretrained_model_name_or_path, "tokenizer_config.json", true, **options)

    # Override the legacy option when options[:legacy] is not nil. merge
    # (rather than []=) leaves the Hash returned by get_model_json untouched.
    unless options[:legacy].nil?
      tokenizer_config = tokenizer_config.merge("legacy" => options[:legacy])
    end
    [tokenizer_json, tokenizer_config]
  end
end
|
141
|
+
end
|