informers 0.2.0 → 1.0.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +70 -95
- data/lib/informers/configs.rb +48 -0
- data/lib/informers/env.rb +14 -0
- data/lib/informers/model.rb +31 -0
- data/lib/informers/models.rb +294 -0
- data/lib/informers/pipelines.rb +439 -0
- data/lib/informers/tokenizers.rb +141 -0
- data/lib/informers/utils/core.rb +7 -0
- data/lib/informers/utils/hub.rb +240 -0
- data/lib/informers/utils/math.rb +44 -0
- data/lib/informers/utils/tensor.rb +26 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +29 -9
- metadata +21 -41
- data/lib/informers/feature_extraction.rb +0 -59
- data/lib/informers/fill_mask.rb +0 -109
- data/lib/informers/ner.rb +0 -106
- data/lib/informers/question_answering.rb +0 -197
- data/lib/informers/sentiment_analysis.rb +0 -72
- data/lib/informers/text_generation.rb +0 -54
- data/vendor/LICENSE-bert.txt +0 -202
- data/vendor/LICENSE-blingfire.txt +0 -21
- data/vendor/LICENSE-gpt2.txt +0 -24
- data/vendor/LICENSE-roberta.txt +0 -21
- data/vendor/bert_base_cased_tok.bin +0 -0
- data/vendor/bert_base_tok.bin +0 -0
- data/vendor/gpt2.bin +0 -0
- data/vendor/gpt2.i2w +0 -0
- data/vendor/roberta.bin +0 -0
- data/vendor/roberta.i2w +0 -0
--- /dev/null
+++ b/data/lib/informers/pipelines.rb
@@ -0,0 +1,439 @@
+module Informers
+  class Pipeline
+    def initialize(task:, model:, tokenizer: nil, processor: nil)
+      super()
+      @task = task
+      @model = model
+      @tokenizer = tokenizer
+      @processor = processor
+    end
+  end
+
+  class TextClassificationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+    end
+
+    def call(texts, top_k: 1)
+      # Run tokenization
+      model_inputs = @tokenizer.(texts,
+        padding: true,
+        truncation: true
+      )
+
+      # Run model
+      outputs = @model.(model_inputs)
+
+      function_to_apply =
+        if @model.config.problem_type == "multi_label_classification"
+          ->(batch) { Utils.sigmoid(batch) }
+        else
+          ->(batch) { Utils.softmax(batch) } # single_label_classification (default)
+        end
+
+      id2label = @model.config.id2label
+
+      to_return = []
+      outputs.logits.each do |batch|
+        output = function_to_apply.(batch)
+        scores = Utils.get_top_items(output, top_k)
+
+        vals = scores.map do |x|
+          {
+            label: id2label[x[0].to_s],
+            score: x[1]
+          }
+        end
+        if top_k == 1
+          to_return.concat(vals)
+        else
+          to_return << vals
+        end
+      end
+
+      texts.is_a?(Array) ? to_return : to_return[0]
+    end
+  end
+
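A usage sketch for the class above (editorial note, not part of the diff): pipelines are normally constructed through the `Informers.pipeline` factory defined at the end of this file, and the inputs and outputs shown are illustrative.

  classifier = Informers.pipeline("text-classification")
  classifier.("I love Ruby!")
  # => {label: "POSITIVE", score: ...}   (a single string yields a single hash)
  classifier.(["I love Ruby!", "This is bad."], top_k: 2)
  # => one array of {label:, score:} hashes per input when top_k > 1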
+  class TokenClassificationPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+    end
+
+    def call(
+      texts,
+      ignore_labels: ["O"],
+      aggregation_strategy: "simple"
+    )
+      is_batched = texts.is_a?(Array)
+
+      # Run tokenization
+      model_inputs = @tokenizer.(is_batched ? texts : [texts],
+        padding: true,
+        truncation: true,
+        return_offsets: true
+      )
+
+      # Run model
+      outputs = @model.(model_inputs)
+
+      logits = outputs.logits
+      id2label = @model.config.id2label
+
+      to_return = []
+      logits.length.times do |i|
+        ids = model_inputs[:input_ids][i]
+        batch = logits[i]
+        offsets = model_inputs[:offsets][i]
+
+        # List of tokens that aren't ignored
+        tokens = []
+        batch.length.times do |j|
+          token_data = batch[j]
+          top_score_index = Utils.max(token_data)[1]
+
+          entity = id2label ? id2label[top_score_index.to_s] : "LABEL_#{top_score_index}"
+          if ignore_labels.include?(entity)
+            # We predicted a token that should be ignored. So, we skip it.
+            next
+          end
+
+          # TODO add option to keep special tokens?
+          word = @tokenizer.decode([ids[j]], skip_special_tokens: true)
+          if word == ""
+            # Was a special token. So, we skip it.
+            next
+          end
+
+          scores = Utils.softmax(token_data)
+
+          tokens << {
+            entity: entity,
+            score: scores[top_score_index],
+            index: j,
+            word: word,
+            start: offsets[j][0],
+            end: offsets[j][1]
+          }
+        end
+
+        case aggregation_strategy
+        when "simple"
+          tokens = group_entities(tokens)
+        when "none"
+          # do nothing
+        else
+          raise ArgumentError, "Invalid aggregation_strategy"
+        end
+
+        to_return << tokens
+      end
+      is_batched ? to_return : to_return[0]
+    end
+
+    def group_sub_entities(entities)
+      # Get the first entity in the entity group
+      entity = entities[0][:entity].split("-", 2)[-1]
+      scores = entities.map { |entity| entity[:score] }
+      tokens = entities.map { |entity| entity[:word] }
+
+      entity_group = {
+        entity_group: entity,
+        score: scores.sum / scores.count.to_f,
+        word: @tokenizer.convert_tokens_to_string(tokens),
+        start: entities[0][:start],
+        end: entities[-1][:end]
+      }
+      entity_group
+    end
+
+    def get_tag(entity_name)
+      if entity_name.start_with?("B-")
+        bi = "B"
+        tag = entity_name[2..]
+      elsif entity_name.start_with?("I-")
+        bi = "I"
+        tag = entity_name[2..]
+      else
+        # It's not in B-, I- format
+        # Default to I- for continuation.
+        bi = "I"
+        tag = entity_name
+      end
+      [bi, tag]
+    end
+
+    def group_entities(entities)
+      entity_groups = []
+      entity_group_disagg = []
+
+      entities.each do |entity|
+        if entity_group_disagg.empty?
+          entity_group_disagg << entity
+          next
+        end
+
+        # If the current entity is similar and adjacent to the previous entity,
+        # append it to the disaggregated entity group
+        # The split is meant to account for the "B" and "I" prefixes
+        # Shouldn't merge if both entities are B-type
+        bi, tag = get_tag(entity[:entity])
+        _last_bi, last_tag = get_tag(entity_group_disagg[-1][:entity])
+
+        if tag == last_tag && bi != "B"
+          # Modify subword type to be previous_type
+          entity_group_disagg << entity
+        else
+          # If the current entity is different from the previous entity
+          # aggregate the disaggregated entity group
+          entity_groups << group_sub_entities(entity_group_disagg)
+          entity_group_disagg = [entity]
+        end
+      end
+      if entity_group_disagg.any?
+        # it's the last entity, add it to the entity groups
+        entity_groups << group_sub_entities(entity_group_disagg)
+      end
+
+      entity_groups
+    end
+  end
+
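A usage sketch for token classification (editorial note, not part of the diff): "ner" is resolved to "token-classification" via TASK_ALIASES further down this file, and the entity values shown are illustrative.

  ner = Informers.pipeline("ner")
  ner.("Ruby was created by Yukihiro Matsumoto")
  # => [{entity_group: "PER", score: ..., word: "Yukihiro Matsumoto", start: ..., end: ...}]
  ner.("Ruby was created by Yukihiro Matsumoto", aggregation_strategy: "none")
  # => one hash per token with :entity, :score, :index, :word, :start, :end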
+  class QuestionAnsweringPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+    end
+
+    def call(question, context, top_k: 1)
+      # Run tokenization
+      inputs = @tokenizer.(question,
+        text_pair: context,
+        padding: true,
+        truncation: true,
+        return_offsets: true
+      )
+
+      output = @model.(inputs)
+
+      to_return = []
+      output.start_logits.length.times do |j|
+        ids = inputs[:input_ids][j]
+        sep_index = ids.index(@tokenizer.sep_token_id)
+        offsets = inputs[:offsets][j]
+
+        s1 = Utils.softmax(output.start_logits[j])
+          .map.with_index
+          .select { |x| x[1] > sep_index }
+        e1 = Utils.softmax(output.end_logits[j])
+          .map.with_index
+          .select { |x| x[1] > sep_index }
+
+        options = s1.product(e1)
+          .select { |x| x[0][1] <= x[1][1] }
+          .map { |x| [x[0][1], x[1][1], x[0][0] * x[1][0]] }
+          .sort_by { |v| -v[2] }
+
+        [options.length, top_k].min.times do |k|
+          start, end_, score = options[k]
+
+          answer_tokens = ids.slice(start, end_ + 1)
+
+          answer = @tokenizer.decode(answer_tokens,
+            skip_special_tokens: true
+          )
+
+          to_return << {
+            answer:,
+            score:,
+            start: offsets[start][0],
+            end: offsets[end_][1]
+          }
+        end
+      end
+
+      question.is_a?(Array) ? to_return : to_return[0]
+    end
+  end
+
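A usage sketch for question answering (editorial note, not part of the diff; the answer and scores are illustrative):

  qa = Informers.pipeline("question-answering")
  qa.("Who created Ruby?", "Ruby is a programming language created by Yukihiro Matsumoto.")
  # => {answer: "Yukihiro Matsumoto", score: ..., start: ..., end: ...}
  # :start and :end are character offsets into the question/context encoding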
+  class FeatureExtractionPipeline < Pipeline
+    def initialize(**options)
+      super(**options)
+    end
+
+    def call(
+      texts,
+      pooling: "none",
+      normalize: false,
+      quantize: false,
+      precision: "binary"
+    )
+      # Run tokenization
+      model_inputs = @tokenizer.(texts,
+        padding: true,
+        truncation: true
+      )
+
+      # Run model
+      outputs = @model.(model_inputs)
+
+      # TODO check outputs.last_hidden_state
+      result = outputs.logits
+      case pooling
+      when "none"
+        # Skip pooling
+      when "mean"
+        result = Utils.mean_pooling(result, model_inputs[:attention_mask])
+      when "cls"
+        result = result.map(&:first)
+      else
+        raise Error, "Pooling method '#{pooling}' not supported."
+      end
+
+      if normalize
+        result = Utils.normalize(result)
+      end
+
+      if quantize
+        result = quantize_embeddings(result, precision)
+      end
+
+      texts.is_a?(Array) ? result : result[0]
+    end
+  end
+
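A usage sketch for sentence embeddings (editorial note, not part of the diff): with the default pooling of "none" the pipeline returns token-level hidden states, so mean pooling plus normalization is the usual way to get one embedding per input.

  extractor = Informers.pipeline("feature-extraction")
  embeddings = extractor.(["Hello world", "Ruby is nice"], pooling: "mean", normalize: true)
  # => one normalized embedding (array of floats) per input string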
+  SUPPORTED_TASKS = {
+    "text-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TextClassificationPipeline,
+      model: AutoModelForSequenceClassification,
+      default: {
+        model: "Xenova/distilbert-base-uncased-finetuned-sst-2-english"
+      },
+      type: "text"
+    },
+    "token-classification" => {
+      tokenizer: AutoTokenizer,
+      pipeline: TokenClassificationPipeline,
+      model: AutoModelForTokenClassification,
+      default: {
+        model: "Xenova/bert-base-multilingual-cased-ner-hrl"
+      },
+      type: "text"
+    },
+    "question-answering" => {
+      tokenizer: AutoTokenizer,
+      pipeline: QuestionAnsweringPipeline,
+      model: AutoModelForQuestionAnswering,
+      default: {
+        model: "Xenova/distilbert-base-cased-distilled-squad"
+      },
+      type: "text"
+    },
+    "feature-extraction" => {
+      tokenizer: AutoTokenizer,
+      pipeline: FeatureExtractionPipeline,
+      model: AutoModel,
+      default: {
+        model: "Xenova/all-MiniLM-L6-v2"
+      },
+      type: "text"
+    }
+  }
+
+  TASK_ALIASES = {
+    "sentiment-analysis" => "text-classification",
+    "ner" => "token-classification"
+  }
+
+  DEFAULT_PROGRESS_CALLBACK = lambda do |msg|
+    stream = $stderr
+    tty = stream.tty?
+    width = tty ? stream.winsize[1] : 80
+
+    if msg[:status] == "progress" && tty
+      stream.print "\r#{Utils::Hub.display_progress(msg[:file], width, msg[:size], msg[:total_size])}"
+    elsif msg[:status] == "done" && !msg[:cache_hit]
+      if tty
+        stream.puts
+      else
+        stream.puts Utils::Hub.display_progress(msg[:file], width, 1, 1)
+      end
+    end
+  end
+
+  class << self
+    def pipeline(
+      task,
+      model = nil,
+      quantized: true,
+      progress_callback: DEFAULT_PROGRESS_CALLBACK,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      model_file_name: nil
+    )
+      # Apply aliases
+      task = TASK_ALIASES[task] || task
+
+      # Get pipeline info
+      pipeline_info = SUPPORTED_TASKS[task.split("_", 1)[0]]
+      if !pipeline_info
+        raise Error, "Unsupported pipeline: #{task}. Must be one of #{SUPPORTED_TASKS.keys}"
+      end
+
+      # Use model if specified, otherwise, use default
+      if !model
+        model = pipeline_info[:default][:model]
+        warn "No model specified. Using default model: #{model.inspect}."
+      end
+
+      pretrained_options = {
+        quantized:,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:,
+        model_file_name:
+      }
+
+      classes = {
+        tokenizer: pipeline_info[:tokenizer],
+        model: pipeline_info[:model],
+        processor: pipeline_info[:processor]
+      }
+
+      # Load model, tokenizer, and processor (if they exist)
+      results = load_items(classes, model, pretrained_options)
+      results[:task] = task
+
+      Utils.dispatch_callback(progress_callback, {
+        status: "ready",
+        task: task,
+        model: model
+      })
+
+      pipeline_class = pipeline_info.fetch(:pipeline)
+      pipeline_class.new(**results)
+    end
+
+    private
+
+    def load_items(mapping, model, pretrained_options)
+      result = {}
+
+      mapping.each do |name, cls|
+        next if !cls
+
+        if cls.is_a?(Array)
+          raise Todo
+        else
+          result[name] = cls.from_pretrained(model, **pretrained_options)
+        end
+      end
+
+      result
+    end
+  end
+end
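A sketch of calling the factory above with explicit options (editorial note, not part of the diff): the model id is the text-classification default from SUPPORTED_TASKS, while the cache path is a hypothetical example.

  pipe = Informers.pipeline(
    "sentiment-analysis",                 # alias, resolved to "text-classification"
    "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
    quantized: true,
    cache_dir: "tmp/informers",           # hypothetical local cache location
    progress_callback: ->(msg) {}         # silence the default stderr progress output
  )
  pipe.("Model files are downloaded once and then reused from the cache")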
--- /dev/null
+++ b/data/lib/informers/tokenizers.rb
@@ -0,0 +1,141 @@
+module Informers
+  class PreTrainedTokenizer
+    attr_reader :sep_token_id
+
+    def initialize(tokenizer_json, tokenizer_config)
+      super()
+
+      @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)
+
+      @sep_token = tokenizer_config["sep_token"]
+      @sep_token_id = @tokenizer.token_to_id(@sep_token)
+
+      @model_max_length = tokenizer_config["model_max_length"]
+    end
+
+    def call(
+      text,
+      text_pair: nil,
+      add_special_tokens: true,
+      padding: false,
+      truncation: nil,
+      max_length: nil,
+      return_tensor: true,
+      return_token_type_ids: true, # TODO change default
+      return_offsets: false
+    )
+      is_batched = text.is_a?(Array)
+
+      if is_batched
+        if text.length == 0
+          raise Error, "text array must be non-empty"
+        end
+
+        if !text_pair.nil?
+          if !text_pair.is_a?(Array)
+            raise Error, "text_pair must also be an array"
+          elsif text.length != text_pair.length
+            raise Error, "text and text_pair must have the same length"
+          end
+        end
+      end
+
+      if padding
+        @tokenizer.enable_padding
+      else
+        @tokenizer.no_padding
+      end
+
+      if truncation
+        @tokenizer.enable_truncation(max_length || @model_max_length)
+      else
+        @tokenizer.no_truncation
+      end
+
+      if is_batched
+        input = text_pair ? text.zip(text_pair) : text
+        encoded = @tokenizer.encode_batch(input, add_special_tokens: add_special_tokens)
+      else
+        encoded = [@tokenizer.encode(text, text_pair, add_special_tokens: add_special_tokens)]
+      end
+
+      result = {input_ids: encoded.map(&:ids), attention_mask: encoded.map(&:attention_mask)}
+      if return_token_type_ids
+        result[:token_type_ids] = encoded.map(&:type_ids)
+      end
+      if return_offsets
+        result[:offsets] = encoded.map(&:offsets)
+      end
+      result
+    end
+
+    def decode(tokens, skip_special_tokens:)
+      @tokenizer.decode(tokens, skip_special_tokens: skip_special_tokens)
+    end
+
+    def convert_tokens_to_string(tokens)
+      @tokenizer.decoder.decode(tokens)
+    end
+  end
+
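A usage sketch for PreTrainedTokenizer#call (editorial note, not part of the diff), assuming a tokenizer instance obtained via the AutoTokenizer class defined below:

  enc = tokenizer.(["Hello world", "Hi"], padding: true, truncation: true, return_offsets: true)
  enc[:input_ids]       # => one array of token ids per input
  enc[:attention_mask]  # => 1 for real tokens, 0 for padding
  enc[:offsets]         # => [start, end] character offsets per token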
+  class BertTokenizer < PreTrainedTokenizer
+    # TODO
+    # self.return_token_type_ids = true
+  end
+
+  class DistilBertTokenizer < PreTrainedTokenizer
+  end
+
+  class AutoTokenizer
+    TOKENIZER_CLASS_MAPPING = {
+      "BertTokenizer" => BertTokenizer,
+      "DistilBertTokenizer" => DistilBertTokenizer
+    }
+
+    def self.from_pretrained(
+      pretrained_model_name_or_path,
+      quantized: true,
+      progress_callback: nil,
+      config: nil,
+      cache_dir: nil,
+      local_files_only: false,
+      revision: "main",
+      legacy: nil,
+      **kwargs
+    )
+      tokenizer_json, tokenizer_config = load_tokenizer(
+        pretrained_model_name_or_path,
+        quantized:,
+        progress_callback:,
+        config:,
+        cache_dir:,
+        local_files_only:,
+        revision:,
+        legacy:
+      )
+
+      # Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
+      tokenizer_name = tokenizer_config["tokenizer_class"]&.delete_suffix("Fast") || "PreTrainedTokenizer"
+
+      cls = TOKENIZER_CLASS_MAPPING[tokenizer_name]
+      if !cls
+        warn "Unknown tokenizer class #{tokenizer_name.inspect}, attempting to construct from base class."
+        cls = PreTrainedTokenizer
+      end
+      cls.new(tokenizer_json, tokenizer_config)
+    end
+
+    def self.load_tokenizer(pretrained_model_name_or_path, **options)
+      info = [
+        Utils::Hub.get_model_file(pretrained_model_name_or_path, "tokenizer.json", true, **options),
+        Utils::Hub.get_model_json(pretrained_model_name_or_path, "tokenizer_config.json", true, **options),
+      ]
+
+      # Override legacy option if `options.legacy` is not null
+      if !options[:legacy].nil?
+        info[1]["legacy"] = options[:legacy]
+      end
+      info
+    end
+  end
+end
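A usage sketch for AutoTokenizer (editorial note, not part of the diff): the model id is the feature-extraction default from pipelines.rb, and the decoded output is illustrative.

  tokenizer = Informers::AutoTokenizer.from_pretrained("Xenova/all-MiniLM-L6-v2")
  enc = tokenizer.("Hello world")
  tokenizer.decode(enc[:input_ids][0], skip_special_tokens: true)
  # => "hello world" (this model's tokenizer lowercases input)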