transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/pipelines/token_classification.rb
@@ -0,0 +1,282 @@
module Transformers
  class TokenClassificationArgumentHandler < ArgumentHandler
  end

  class AggregationStrategy < ExplicitEnum
    NONE = "none"
    SIMPLE = "simple"
    FIRST = "first"
    AVERAGE = "average"
    MAX = "max"
  end

  class TokenClassificationPipeline < ChunkPipeline
    extend ClassAttribute

    class_attribute :default_input_names, "sequences"

    def initialize(*args, args_parser: TokenClassificationArgumentHandler.new, **kwargs)
      super(*args, **kwargs)
      check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)

      @basic_tokenizer = Bert::BertTokenizer::BasicTokenizer.new(do_lower_case: false)
      @args_parser = args_parser
    end

    def _sanitize_parameters(
      ignore_labels: nil,
      grouped_entities: nil,
      ignore_subwords: nil,
      aggregation_strategy: nil,
      offset_mapping: nil,
      stride: nil
    )
      preprocess_params = {}
      if !offset_mapping.nil?
        preprocess_params[:offset_mapping] = offset_mapping
      end

      postprocess_params = {}
      if !grouped_entities.nil? || !ignore_subwords.nil?
        if grouped_entities && ignore_subwords
          aggregation_strategy = AggregationStrategy::FIRST
        elsif grouped_entities && !ignore_subwords
          aggregation_strategy = AggregationStrategy::SIMPLE
        else
          aggregation_strategy = AggregationStrategy::NONE
        end

        if !grouped_entities.nil?
          warn(
            "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to" +
            " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
          )
        end
        if !ignore_subwords.nil?
          warn(
            "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to" +
            " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
          )
        end
      end

      if !aggregation_strategy.nil?
        if aggregation_strategy.is_a?(String)
          aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
        end
        if (
          [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
          !@tokenizer.is_fast
        )
          raise ArgumentError,
            "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option" +
            ' to `"simple"` or use a fast tokenizer.'
        end
        postprocess_params[:aggregation_strategy] = aggregation_strategy
      end
      if !ignore_labels.nil?
        postprocess_params[:ignore_labels] = ignore_labels
      end
      if !stride.nil?
        if stride >= @tokenizer.model_max_length
          raise ArgumentError,
            "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
        end
        if aggregation_strategy == AggregationStrategy::NONE
          raise ArgumentError,
            "`stride` was provided to process all the text but `aggregation_strategy=" +
            "\"#{aggregation_strategy}\"`, please select another one instead."
        else
          if @tokenizer.is_fast
            tokenizer_params = {
              return_overflowing_tokens: true,
              padding: true,
              stride: stride
            }
            preprocess_params[:tokenizer_params] = tokenizer_params
          else
            raise ArgumentError,
              "`stride` was provided to process all the text but you're using a slow tokenizer." +
              " Please use a fast tokenizer."
          end
        end
      end
      [preprocess_params, {}, postprocess_params]
    end

    def preprocess(sentence, offset_mapping: nil, **preprocess_params)
      tokenizer_params = preprocess_params.delete(:tokenizer_params) { {} }
      truncation = @tokenizer.model_max_length && @tokenizer.model_max_length > 0
      inputs = @tokenizer.(
        sentence,
        return_tensors: @framework,
        truncation: truncation,
        return_special_tokens_mask: true,
        return_offsets_mapping: @tokenizer.is_fast,
        **tokenizer_params
      )
      inputs.delete(:overflow_to_sample_mapping)
      num_chunks = inputs[:input_ids].length

      num_chunks.times do |i|
        if @framework == "tf"
          raise Todo
        else
          model_inputs = inputs.to_h { |k, v| [k, v[i].unsqueeze(0)] }
        end
        if !offset_mapping.nil?
          model_inputs[:offset_mapping] = offset_mapping
        end
        model_inputs[:sentence] = i == 0 ? sentence : nil
        model_inputs[:is_last] = (i == num_chunks - 1)

        yield model_inputs
      end
    end

    def _forward(model_inputs)
      # Forward
      special_tokens_mask = model_inputs.delete(:special_tokens_mask)
      offset_mapping = model_inputs.delete(:offset_mapping)
      sentence = model_inputs.delete(:sentence)
      is_last = model_inputs.delete(:is_last)
      if @framework == "tf"
        logits = @model.(**model_inputs)[0]
      else
        output = @model.(**model_inputs)
        logits = output.is_a?(Hash) ? output[:logits] : output[0]
      end

      {
        logits: logits,
        special_tokens_mask: special_tokens_mask,
        offset_mapping: offset_mapping,
        sentence: sentence,
        is_last: is_last,
        **model_inputs
      }
    end

    def postprocess(all_outputs, aggregation_strategy: AggregationStrategy::NONE, ignore_labels: nil)
      if ignore_labels.nil?
        ignore_labels = ["O"]
      end
      all_entities = []
      all_outputs.each do |model_outputs|
        logits = model_outputs[:logits][0].numo
        sentence = all_outputs[0][:sentence]
        input_ids = model_outputs[:input_ids][0]
        offset_mapping = (
          !model_outputs[:offset_mapping].nil? ? model_outputs[:offset_mapping][0] : nil
        )
        special_tokens_mask = model_outputs[:special_tokens_mask][0].numo

        maxes = logits.max(axis: -1).expand_dims(-1)
        shifted_exp = Numo::NMath.exp(logits - maxes)
        scores = shifted_exp / shifted_exp.sum(axis: -1).expand_dims(-1)

        if @framework == "tf"
          raise Todo
        end

        pre_entities = gather_pre_entities(
          sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
        )
        grouped_entities = aggregate(pre_entities, aggregation_strategy)
        # Filter anything that is in self.ignore_labels
        entities =
          grouped_entities.select do |entity|
            !ignore_labels.include?(entity[:entity]) && !ignore_labels.include?(entity[:entity_group])
          end
        all_entities.concat(entities)
      end
      num_chunks = all_outputs.length
      if num_chunks > 1
        all_entities = aggregate_overlapping_entities(all_entities)
      end
      all_entities
    end

    def gather_pre_entities(
      sentence,
      input_ids,
      scores,
      offset_mapping,
      special_tokens_mask,
      aggregation_strategy
    )
      pre_entities = []
      scores.each_over_axis(0).with_index do |token_scores, idx|
        # Filter special_tokens
        if special_tokens_mask[idx] != 0
          next
        end

        word = @tokenizer.convert_ids_to_tokens(input_ids[idx].to_i)
        if !offset_mapping.nil?
          start_ind, end_ind = offset_mapping[idx].to_a
          if !start_ind.is_a?(Integer)
            if @framework == "pt"
              start_ind = start_ind.item
              end_ind = end_ind.item
            end
          end
          word_ref = sentence[start_ind...end_ind]
          if @tokenizer.instance_variable_get(:@tokenizer).respond_to?(:continuing_subword_prefix)
            # This is a BPE, word aware tokenizer, there is a correct way
            # to fuse tokens
            is_subword = word.length != word_ref.length
          else
            is_subword = start_ind > 0 && !sentence[(start_ind - 1)...(start_ind + 1)].include?(" ")
          end

          if input_ids[idx].to_i == @tokenizer.unk_token_id
            word = word_ref
            is_subword = false
          end
        else
          start_ind = nil
          end_ind = nil
          is_subword = nil
        end

        pre_entity = {
          word: word,
          scores: token_scores,
          start: start_ind,
          end: end_ind,
          index: idx,
          is_subword: is_subword
        }
        pre_entities << pre_entity
      end
      pre_entities
    end

    def aggregate(pre_entities, aggregation_strategy)
      if [AggregationStrategy::NONE, AggregationStrategy::SIMPLE].include?(aggregation_strategy)
        entities = []
        pre_entities.each do |pre_entity|
          entity_idx = pre_entity[:scores].argmax
          score = pre_entity[:scores][entity_idx]
          entity = {
            entity: @model.config.id2label[entity_idx],
            score: score,
            index: pre_entity[:index],
            word: pre_entity[:word],
            start: pre_entity[:start],
            end: pre_entity[:end]
          }
          entities << entity
        end
      else
        entities = aggregate_words(pre_entities, aggregation_strategy)
      end

      if aggregation_strategy == AggregationStrategy::NONE
        return entities
      end
      group_entities(entities)
    end
  end
end
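
For orientation: this pipeline is normally reached through the `Transformers.pipeline` helper defined in data/lib/transformers/pipelines/_init.rb (listed above but not shown in this excerpt). A minimal usage sketch, assuming that helper mirrors the Python API it ports; the model id, call syntax, and output values below are illustrative, not verbatim gem output:

    require "transformers-rb"

    # "dslim/bert-base-NER" is an example token-classification checkpoint from the Hub,
    # not something shipped with this gem
    ner = Transformers.pipeline("ner", model: "dslim/bert-base-NER")
    ner.("Sarah lives in Berlin", aggregation_strategy: "simple")
    # => e.g. [{entity_group: "PER", score: 0.99, word: "Sarah", start: 0, end: 5}, ...]

With `aggregation_strategy: "simple"`, postprocess groups word pieces into `entity_group` records; with the default "none", each token is returned separately under `entity`.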
data/lib/transformers/ruby_utils.rb
@@ -0,0 +1,33 @@
module Transformers
  module ClassAttribute
    def class_attribute(name, default = nil)
      singleton_class.attr_writer name
      var = "@#{name}"
      instance_variable_set(var, default)
      singleton_class.define_method(name) do
        # ancestors includes current module
        ancestors.find { |c| c.instance_variable_defined?(var) }.instance_variable_get(var)
      end
      define_method(name) do
        self.class.send(name)
      end
    end
  end

  module Copy
    def self.deepcopy(value, memo = {})
      key = value.object_id
      if !memo.key?(key)
        copy = value.dup
        memo[key] = copy
        if value.is_a?(Hash)
          copy.transform_keys! { |k| deepcopy(k, memo) }
          copy.transform_values! { |v| deepcopy(v, memo) }
        elsif value.is_a?(Array)
          copy.map! { |v| deepcopy(v, memo) }
        end
      end
      memo[key]
    end
  end
end
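
ClassAttribute stores the value in a class-level instance variable and resolves reads through `ancestors`, so a subclass inherits the default until it assigns its own value; the writer only sets the variable on the receiver. (Copy.deepcopy similarly recurses into hashes and arrays, using the memo keyed by object_id to preserve shared references.) A small behavioral sketch with hypothetical classes, derived from the code above:

    class Base
      extend Transformers::ClassAttribute
      class_attribute :default_input_names, "sequences"
    end

    class Child < Base; end

    Child.default_input_names        # => "sequences" (found on Base via ancestors)
    Child.default_input_names = "tokens"
    Child.default_input_names        # => "tokens"
    Base.default_input_names         # => "sequences" (unchanged)
    Child.new.default_input_names    # => "tokens" (instance method delegates to the class)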
data/lib/transformers/sentence_transformer.rb
@@ -0,0 +1,37 @@
module Transformers
  class SentenceTransformer
    def initialize(model_id)
      @model_id = model_id
      @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
      @model = Transformers::AutoModel.from_pretrained(model_id)
    end

    def encode(sentences)
      singular = sentences.is_a?(String)
      sentences = [sentences] if singular

      input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
      output = Torch.no_grad { @model.(**input) }[0]

      # TODO check modules.json
      if [
        "sentence-transformers/all-MiniLM-L6-v2",
        "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
      ].include?(@model_id)
        output = mean_pooling(output, input[:attention_mask])
        output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
      else
        output = output[0.., 0].to_a
      end

      singular ? output[0] : output
    end

    private

    def mean_pooling(output, attention_mask)
      input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
      Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
    end
  end
end
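
A short usage sketch for the class above; the model id is one of the two checkpoints the code explicitly special-cases for mean pooling and L2 normalization, and the weights are fetched from the Hugging Face Hub on first use via from_pretrained:

    require "transformers-rb"

    model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")

    # A single string returns a single embedding (an Array of Floats);
    # an Array of strings returns one embedding per sentence.
    embedding = model.encode("This is a sentence")
    embeddings = model.encode(["first sentence", "second sentence"])

For any other model id, encode falls back to returning the first token's hidden state without normalization, as the TODO about modules.json notes.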
data/lib/transformers/tokenization_utils.rb
@@ -0,0 +1,152 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class PreTrainedTokenizer < PreTrainedTokenizerBase
    def initialize(**kwargs)

      # 2. init `_added_tokens_decoder` if child class did not
      if !instance_variable_defined?(:@added_tokens_decoder)
        @added_tokens_decoder = {}
      end

      # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
      @added_tokens_decoder.merge!(kwargs.delete(:added_tokens_decoder) { {} })
      @added_tokens_encoder = @added_tokens_decoder.to_h { |k, v| [k.content, v] }

      # 4 init the parent class
      super(**kwargs)
    end

    def is_fast
      false
    end

    def vocab_size
      raise NotImplementedError
    end

    def tokenize(text, **kwargs)
      raise Todo
    end

    def _encode_plus(
      text:,
      text_pair: nil,
      add_special_tokens: true,
      padding_strategy: PaddingStrategy::DO_NOT_PAD,
      truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      get_input_ids = lambda do |text|
        if text.is_a?(String)
          tokens = tokenize(text, **kwargs)
          convert_tokens_to_ids(tokens)
        elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(String)
          if is_split_into_words
            raise Todo
          else
            convert_tokens_to_ids(text)
          end
        elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(Integer)
          text
        else
          if is_split_into_words
            raise ArgumentError,
              "Input #{text} is not valid. Should be a string or a list/tuple of strings when" +
              " `is_split_into_words=True`."
          else
            raise ArgumentError,
              "Input #{text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" +
              " integers."
          end
        end
      end

      if return_offsets_mapping
        raise RuntimeError,
          "return_offset_mapping is not available when using Ruby tokenizers. " +
          "To use this feature, change your tokenizer to one deriving from " +
          "Transformers::PreTrainedTokenizerFast. " +
          "More information on available tokenizers at " +
          "https://github.com/huggingface/transformers/pull/2674"
      end

      first_ids = get_input_ids.(text)
      second_ids = !text_pair.nil? ? get_input_ids.(text_pair) : nil

      prepare_for_model(
        first_ids,
        pair_ids: second_ids,
        add_special_tokens: add_special_tokens,
        padding: padding_strategy,
        truncation: truncation_strategy,
        max_length: max_length,
        stride: stride,
        pad_to_multiple_of: pad_to_multiple_of,
        return_tensors: return_tensors,
        prepend_batch_axis: true,
        return_attention_mask: return_attention_mask,
        return_token_type_ids: return_token_type_ids,
        return_overflowing_tokens: return_overflowing_tokens,
        return_special_tokens_mask: return_special_tokens_mask,
        return_length: return_length,
        verbose: verbose
      )
    end

    def convert_tokens_to_ids(tokens)
      if tokens.nil?
        return nil
      end

      if tokens.is_a?(String)
        return _convert_token_to_id_with_added_voc(tokens)
      end

      ids = []
      tokens.each do |token|
        ids << _convert_token_to_id_with_added_voc(token)
      end
      ids
    end

    def _convert_token_to_id_with_added_voc(token)
      if token.nil?
        return nil
      end

      if @added_tokens_encoder.include?(token)
        return @added_tokens_encoder[token]
      end
      _convert_token_to_id(token)
    end

    def _convert_token_to_id(token)
      raise NotImplementedError
    end
  end
end
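
PreTrainedTokenizer leaves vocab_size, tokenize, and _convert_token_to_id abstract; the concrete slow tokenizers in this release (e.g. Bert::BertTokenizer) supply them. A minimal, hypothetical subclass sketch showing the contract convert_tokens_to_ids and _encode_plus rely on; real subclasses also wire up special tokens and vocabulary loading, which the base class handles through kwargs:

    class ToyTokenizer < Transformers::PreTrainedTokenizer
      def initialize(vocab, **kwargs)
        @vocab = vocab  # e.g. {"hello" => 0, "world" => 1, "[UNK]" => 2}
        super(**kwargs)
      end

      def vocab_size
        @vocab.length
      end

      # whitespace "tokenization", just to make the flow concrete
      def tokenize(text, **kwargs)
        text.downcase.split
      end

      def _convert_token_to_id(token)
        @vocab.fetch(token, @vocab["[UNK]"])
      end
    end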