transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
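
The gem is loaded through the two entry points in the list above: data/lib/transformers-rb.rb is a one-line shim and data/lib/transformers.rb requires the rest of the library. A minimal setup sketch (the shim's exact contents are not shown here; both files sit under lib/, so either require name should resolve):

  # Gemfile
  gem "transformers-rb"

  # app code -- Bundler auto-requires "transformers-rb", or require it yourself
  require "transformers"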
data/lib/transformers/pipelines/token_classification.rb
@@ -0,0 +1,282 @@
module Transformers
  class TokenClassificationArgumentHandler < ArgumentHandler
  end

  class AggregationStrategy < ExplicitEnum
    NONE = "none"
    SIMPLE = "simple"
    FIRST = "first"
    AVERAGE = "average"
    MAX = "max"
  end

  class TokenClassificationPipeline < ChunkPipeline
    extend ClassAttribute

    class_attribute :default_input_names, "sequences"

    def initialize(*args, args_parser: TokenClassificationArgumentHandler.new, **kwargs)
      super(*args, **kwargs)
      check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)

      @basic_tokenizer = Bert::BertTokenizer::BasicTokenizer.new(do_lower_case: false)
      @args_parser = args_parser
    end

    def _sanitize_parameters(
      ignore_labels: nil,
      grouped_entities: nil,
      ignore_subwords: nil,
      aggregation_strategy: nil,
      offset_mapping: nil,
      stride: nil
    )
      preprocess_params = {}
      if !offset_mapping.nil?
        preprocess_params[:offset_mapping] = offset_mapping
      end

      postprocess_params = {}
      if !grouped_entities.nil? || !ignore_subwords.nil?
        if grouped_entities && ignore_subwords
          aggregation_strategy = AggregationStrategy::FIRST
        elsif grouped_entities && !ignore_subwords
          aggregation_strategy = AggregationStrategy::SIMPLE
        else
          aggregation_strategy = AggregationStrategy::NONE
        end

        if !grouped_entities.nil?
          warn(
            "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to" +
            " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
          )
        end
        if !ignore_subwords.nil?
          warn(
            "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to" +
            " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
          )
        end
      end

      if !aggregation_strategy.nil?
        if aggregation_strategy.is_a?(String)
          aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
        end
        if (
          [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
          !@tokenizer.is_fast
        )
          raise ArgumentError,
            "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option" +
            ' to `"simple"` or use a fast tokenizer.'
        end
        postprocess_params[:aggregation_strategy] = aggregation_strategy
      end
      if !ignore_labels.nil?
        postprocess_params[:ignore_labels] = ignore_labels
      end
      if !stride.nil?
        if stride >= @tokenizer.model_max_length
          raise ArgumentError,
            "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
        end
        if aggregation_strategy == AggregationStrategy::NONE
          raise ArgumentError,
            "`stride` was provided to process all the text but `aggregation_strategy=" +
            "\"#{aggregation_strategy}\"`, please select another one instead."
        else
          if @tokenizer.is_fast
            tokenizer_params = {
              return_overflowing_tokens: true,
              padding: true,
              stride: stride
            }
            preprocess_params[:tokenizer_params] = tokenizer_params
          else
            raise ArgumentError,
              "`stride` was provided to process all the text but you're using a slow tokenizer." +
              " Please use a fast tokenizer."
          end
        end
      end
      [preprocess_params, {}, postprocess_params]
    end

    def preprocess(sentence, offset_mapping: nil, **preprocess_params)
      tokenizer_params = preprocess_params.delete(:tokenizer_params) { {} }
      truncation = @tokenizer.model_max_length && @tokenizer.model_max_length > 0
      inputs = @tokenizer.(
        sentence,
        return_tensors: @framework,
        truncation: truncation,
        return_special_tokens_mask: true,
        return_offsets_mapping: @tokenizer.is_fast,
        **tokenizer_params
      )
      inputs.delete(:overflow_to_sample_mapping)
      num_chunks = inputs[:input_ids].length

      num_chunks.times do |i|
        if @framework == "tf"
          raise Todo
        else
          model_inputs = inputs.to_h { |k, v| [k, v[i].unsqueeze(0)] }
        end
        if !offset_mapping.nil?
          model_inputs[:offset_mapping] = offset_mapping
        end
        model_inputs[:sentence] = i == 0 ? sentence : nil
        model_inputs[:is_last] = (i == num_chunks - 1)

        yield model_inputs
      end
    end

    def _forward(model_inputs)
      # Forward
      special_tokens_mask = model_inputs.delete(:special_tokens_mask)
      offset_mapping = model_inputs.delete(:offset_mapping)
      sentence = model_inputs.delete(:sentence)
      is_last = model_inputs.delete(:is_last)
      if @framework == "tf"
        logits = @model.(**model_inputs)[0]
      else
        output = @model.(**model_inputs)
        logits = output.is_a?(Hash) ? output[:logits] : output[0]
      end

      {
        logits: logits,
        special_tokens_mask: special_tokens_mask,
        offset_mapping: offset_mapping,
        sentence: sentence,
        is_last: is_last,
        **model_inputs
      }
    end

    def postprocess(all_outputs, aggregation_strategy: AggregationStrategy::NONE, ignore_labels: nil)
      if ignore_labels.nil?
        ignore_labels = ["O"]
      end
      all_entities = []
      all_outputs.each do |model_outputs|
        logits = model_outputs[:logits][0].numo
        sentence = all_outputs[0][:sentence]
        input_ids = model_outputs[:input_ids][0]
        offset_mapping = (
          !model_outputs[:offset_mapping].nil? ? model_outputs[:offset_mapping][0] : nil
        )
        special_tokens_mask = model_outputs[:special_tokens_mask][0].numo

        maxes = logits.max(axis: -1).expand_dims(-1)
        shifted_exp = Numo::NMath.exp(logits - maxes)
        scores = shifted_exp / shifted_exp.sum(axis: -1).expand_dims(-1)

        if @framework == "tf"
          raise Todo
        end

        pre_entities = gather_pre_entities(
          sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
        )
        grouped_entities = aggregate(pre_entities, aggregation_strategy)
        # Filter anything that is in self.ignore_labels
        entities =
          grouped_entities.select do |entity|
            !ignore_labels.include?(entity[:entity]) && !ignore_labels.include?(entity[:entity_group])
          end
        all_entities.concat(entities)
      end
      num_chunks = all_outputs.length
      if num_chunks > 1
        all_entities = aggregate_overlapping_entities(all_entities)
      end
      all_entities
    end

    def gather_pre_entities(
      sentence,
      input_ids,
      scores,
      offset_mapping,
      special_tokens_mask,
      aggregation_strategy
    )
      pre_entities = []
      scores.each_over_axis(0).with_index do |token_scores, idx|
        # Filter special_tokens
        if special_tokens_mask[idx] != 0
          next
        end

        word = @tokenizer.convert_ids_to_tokens(input_ids[idx].to_i)
        if !offset_mapping.nil?
          start_ind, end_ind = offset_mapping[idx].to_a
          if !start_ind.is_a?(Integer)
            if @framework == "pt"
              start_ind = start_ind.item
              end_ind = end_ind.item
            end
          end
          word_ref = sentence[start_ind...end_ind]
          if @tokenizer.instance_variable_get(:@tokenizer).respond_to?(:continuing_subword_prefix)
            # This is a BPE, word aware tokenizer, there is a correct way
            # to fuse tokens
            is_subword = word.length != word_ref.length
          else
            is_subword = start_ind > 0 && !sentence[(start_ind - 1)...(start_ind + 1)].include?(" ")
          end

          if input_ids[idx].to_i == @tokenizer.unk_token_id
            word = word_ref
            is_subword = false
          end
        else
          start_ind = nil
          end_ind = nil
          is_subword = nil
        end

        pre_entity = {
          word: word,
          scores: token_scores,
          start: start_ind,
          end: end_ind,
          index: idx,
          is_subword: is_subword
        }
        pre_entities << pre_entity
      end
      pre_entities
    end

    def aggregate(pre_entities, aggregation_strategy)
      if [AggregationStrategy::NONE, AggregationStrategy::SIMPLE].include?(aggregation_strategy)
        entities = []
        pre_entities.each do |pre_entity|
          entity_idx = pre_entity[:scores].argmax
          score = pre_entity[:scores][entity_idx]
          entity = {
            entity: @model.config.id2label[entity_idx],
            score: score,
            index: pre_entity[:index],
            word: pre_entity[:word],
            start: pre_entity[:start],
            end: pre_entity[:end]
          }
          entities << entity
        end
      else
        entities = aggregate_words(pre_entities, aggregation_strategy)
      end

      if aggregation_strategy == AggregationStrategy::NONE
        return entities
      end
      group_entities(entities)
    end
  end
end
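
The pipeline mirrors the Python token-classification pipeline: `grouped_entities`/`ignore_subwords` are deprecated in favour of `aggregation_strategy`, and the `first`, `max`, and `average` strategies require a fast tokenizer. A usage sketch, assuming the `Transformers.pipeline` factory defined in pipelines/_init.rb (not shown in this section) registers this class under a token-classification/NER task name:

  # hypothetical task name and factory; the output shape is illustrative
  ner = Transformers.pipeline("ner", aggregation_strategy: "simple")
  ner.("Ruby is a programming language created by Matz")
  # => [{entity_group: "PER", score: ..., word: "Matz", start: ..., end: ...}]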
data/lib/transformers/ruby_utils.rb
@@ -0,0 +1,33 @@
module Transformers
  module ClassAttribute
    def class_attribute(name, default = nil)
      singleton_class.attr_writer name
      var = "@#{name}"
      instance_variable_set(var, default)
      singleton_class.define_method(name) do
        # ancestors includes current module
        ancestors.find { |c| c.instance_variable_defined?(var) }.instance_variable_get(var)
      end
      define_method(name) do
        self.class.send(name)
      end
    end
  end

  module Copy
    def self.deepcopy(value, memo = {})
      key = value.object_id
      if !memo.key?(key)
        copy = value.dup
        memo[key] = copy
        if value.is_a?(Hash)
          copy.transform_keys! { |k| deepcopy(k, memo) }
          copy.transform_values! { |v| deepcopy(v, memo) }
        elsif value.is_a?(Array)
          copy.map! { |v| deepcopy(v, memo) }
        end
      end
      memo[key]
    end
  end
end
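
Both helpers are small enough to exercise directly. In the sketch below the pipeline classes are illustrative, not part of the gem; `class_attribute` gives each subclass its own overridable value (looked up through `ancestors`), and `Copy.deepcopy` recursively copies nested hashes and arrays:

  class BasePipeline
    extend Transformers::ClassAttribute

    class_attribute :default_input_names, "sequences"
  end

  class MyPipeline < BasePipeline
  end

  MyPipeline.default_input_names       # => "sequences" (found on BasePipeline via ancestors)
  MyPipeline.default_input_names = "documents"
  MyPipeline.default_input_names       # => "documents"
  BasePipeline.default_input_names     # => "sequences" (unchanged)
  MyPipeline.new.default_input_names   # => "documents" (instance method delegates to the class)

  original = {config: {"labels" => ["A", "B"]}}
  copy = Transformers::Copy.deepcopy(original)
  copy[:config]["labels"] << "C"
  original[:config]["labels"]          # => ["A", "B"] (nested values were copied)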
data/lib/transformers/sentence_transformer.rb
@@ -0,0 +1,37 @@
module Transformers
  class SentenceTransformer
    def initialize(model_id)
      @model_id = model_id
      @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
      @model = Transformers::AutoModel.from_pretrained(model_id)
    end

    def encode(sentences)
      singular = sentences.is_a?(String)
      sentences = [sentences] if singular

      input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
      output = Torch.no_grad { @model.(**input) }[0]

      # TODO check modules.json
      if [
        "sentence-transformers/all-MiniLM-L6-v2",
        "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
      ].include?(@model_id)
        output = mean_pooling(output, input[:attention_mask])
        output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
      else
        output = output[0.., 0].to_a
      end

      singular ? output[0] : output
    end

    private

    def mean_pooling(output, attention_mask)
      input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
      Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
    end
  end
end
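
As written above, `encode` returns a single embedding for a string and an array of embeddings for an array; for the two hard-coded MiniLM model ids the token embeddings are mean-pooled and L2-normalized, otherwise the first ([CLS]) token embedding is used. A usage sketch (downloads the model on first use and relies on the torch-rb backend):

  model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")

  embedding = model.encode("How is the weather today?")
  embedding.length   # => 384, the hidden size of all-MiniLM-L6-v2

  embeddings = model.encode([
    "How is the weather today?",
    "What is the weather like today?"
  ])
  embeddings.length  # => 2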
data/lib/transformers/tokenization_utils.rb
@@ -0,0 +1,152 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class PreTrainedTokenizer < PreTrainedTokenizerBase
    def initialize(**kwargs)
      # 2. init `_added_tokens_decoder` if child class did not
      if !instance_variable_defined?(:@added_tokens_decoder)
        @added_tokens_decoder = {}
      end

      # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
      @added_tokens_decoder.merge!(kwargs.delete(:added_tokens_decoder) { {} })
      @added_tokens_encoder = @added_tokens_decoder.to_h { |k, v| [k.content, v] }

      # 4. init the parent class
      super(**kwargs)
    end

    def is_fast
      false
    end

    def vocab_size
      raise NotImplementedError
    end

    def tokenize(text, **kwargs)
      raise Todo
    end

    def _encode_plus(
      text:,
      text_pair: nil,
      add_special_tokens: true,
      padding_strategy: PaddingStrategy::DO_NOT_PAD,
      truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      get_input_ids = lambda do |text|
        if text.is_a?(String)
          tokens = tokenize(text, **kwargs)
          convert_tokens_to_ids(tokens)
        elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(String)
          if is_split_into_words
            raise Todo
          else
            convert_tokens_to_ids(text)
          end
        elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(Integer)
          text
        else
          if is_split_into_words
            raise ArgumentError,
              "Input #{text} is not valid. Should be a string or a list/tuple of strings when" +
              " `is_split_into_words=True`."
          else
            raise ArgumentError,
              "Input #{text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" +
              " integers."
          end
        end
      end

      if return_offsets_mapping
        raise RuntimeError,
          "return_offset_mapping is not available when using Ruby tokenizers. " +
          "To use this feature, change your tokenizer to one deriving from " +
          "Transformers::PreTrainedTokenizerFast. " +
          "More information on available tokenizers at " +
          "https://github.com/huggingface/transformers/pull/2674"
      end

      first_ids = get_input_ids.(text)
      second_ids = !text_pair.nil? ? get_input_ids.(text_pair) : nil

      prepare_for_model(
        first_ids,
        pair_ids: second_ids,
        add_special_tokens: add_special_tokens,
        padding: padding_strategy,
        truncation: truncation_strategy,
        max_length: max_length,
        stride: stride,
        pad_to_multiple_of: pad_to_multiple_of,
        return_tensors: return_tensors,
        prepend_batch_axis: true,
        return_attention_mask: return_attention_mask,
        return_token_type_ids: return_token_type_ids,
        return_overflowing_tokens: return_overflowing_tokens,
        return_special_tokens_mask: return_special_tokens_mask,
        return_length: return_length,
        verbose: verbose
      )
    end

    def convert_tokens_to_ids(tokens)
      if tokens.nil?
        return nil
      end

      if tokens.is_a?(String)
        return _convert_token_to_id_with_added_voc(tokens)
      end

      ids = []
      tokens.each do |token|
        ids << _convert_token_to_id_with_added_voc(token)
      end
      ids
    end

    def _convert_token_to_id_with_added_voc(token)
      if token.nil?
        return nil
      end

      if @added_tokens_encoder.include?(token)
        return @added_tokens_encoder[token]
      end
      _convert_token_to_id(token)
    end

    def _convert_token_to_id(token)
      raise NotImplementedError
    end
  end
end
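
`vocab_size`, `tokenize`, and `_convert_token_to_id` are the hooks a concrete slow tokenizer has to supply (the shipped BertTokenizer does so in models/bert/tokenization_bert.rb). A minimal illustrative subclass with a made-up whitespace vocabulary, not part of the gem, assuming the base class can be instantiated directly:

  class ToyWhitespaceTokenizer < Transformers::PreTrainedTokenizer
    def initialize(vocab:, **kwargs)
      @vocab = vocab  # hypothetical token => id mapping
      super(**kwargs)
    end

    def vocab_size
      @vocab.size
    end

    def tokenize(text, **kwargs)
      text.downcase.split
    end

    def _convert_token_to_id(token)
      @vocab.fetch(token, @vocab["[UNK]"])
    end
  end

  tok = ToyWhitespaceTokenizer.new(vocab: {"[UNK]" => 0, "hello" => 1, "world" => 2})
  tok.convert_tokens_to_ids(tok.tokenize("Hello world"))  # => [1, 2]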