transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/configuration_utils.rb
@@ -0,0 +1,285 @@
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class PretrainedConfig
    extend ClassAttribute

    class_attribute :model_type, ""
    class_attribute :attribute_map, {}

    # TODO support setter
    def method_missing(m, *args, **kwargs)
      if self.class.attribute_map.include?(m)
        instance_variable_get("@#{self.class.attribute_map[m]}")
      else
        super
      end
    end

    # TODO support setter
    def respond_to_missing?(m, include_private = true)
      self.class.attribute_map.include?(m) || super
    end

    attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
      :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
      :problem_type, :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash

    def initialize(**kwargs)
      @return_dict = kwargs.delete(:return_dict) { true }
      @output_hidden_states = kwargs.delete(:output_hidden_states) { false }
      @output_attentions = kwargs.delete(:output_attentions) { false }
      @pruned_heads = kwargs.delete(:pruned_heads) { {} }
      @tie_word_embeddings = kwargs.delete(:tie_word_embeddings) { true }
      @chunk_size_feed_forward = kwargs.delete(:chunk_size_feed_forward) { 0 }

      # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
      @is_encoder_decoder = kwargs.delete(:is_encoder_decoder) { false }
      @is_decoder = kwargs.delete(:is_decoder) { false }
      @cross_attention_hidden_size = kwargs.delete(:cross_attention_hidden_size)
      @add_cross_attention = kwargs.delete(:add_cross_attention) { false }
      @tie_encoder_decoder = kwargs.delete(:tie_encoder_decoder) { false }

      # Fine-tuning task arguments
      @architectures = kwargs.delete(:architectures)
      @finetuning_task = kwargs.delete(:finetuning_task)
      @id2label = kwargs.delete(:id2label)
      @label2id = kwargs.delete(:label2id)
      if !@label2id.nil? && !@label2id.is_a?(Hash)
        raise ArgumentError, "Argument label2id should be a dictionary."
      end
      if !@id2label.nil?
        if !@id2label.is_a?(Hash)
          raise ArgumentError, "Argument id2label should be a dictionary."
        end
        num_labels = kwargs.delete(:num_labels)
        if !num_labels.nil? && id2label.length != num_labels
          raise Todo
        end
        @id2label = @id2label.transform_keys(&:to_i)
        # Keys are always strings in JSON so convert ids to int here.
      else
        self.num_labels = kwargs.delete(:num_labels) { 2 }
      end

      # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
      @tokenizer_class = kwargs.delete(:tokenizer_class)
      @prefix = kwargs.delete(:prefix)
      @bos_token_id = kwargs.delete(:bos_token_id)
      @pad_token_id = kwargs.delete(:pad_token_id)
      @eos_token_id = kwargs.delete(:eos_token_id)
      @sep_token_id = kwargs.delete(:sep_token_id)

      # regression / multi-label classification
      @problem_type = kwargs.delete(:problem_type)

      # Name or path to the pretrained checkpoint
      @name_or_path = kwargs.delete(:name_or_path).to_s
      # Config hash
      @commit_hash = kwargs.delete(:_commit_hash)

      # TODO set kwargs
      @gradient_checkpointing = kwargs[:gradient_checkpointing]
      @output_past = kwargs[:output_past]
      @tie_weights_ = kwargs[:tie_weights_]
    end

    def name_or_path
      @name_or_path
    end

    def name_or_path=(value)
      @name_or_path = value.to_s
    end

    def num_labels
      @id2label.length
    end

    def num_labels=(num_labels)
      if @id2label.nil? || @id2label.length != num_labels
        @id2label = num_labels.times.to_h { |i| [i, "LABEL_#{i}"] }
        @label2id = @id2label.invert
      end
    end

    def _attn_implementation
      # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
      if instance_variable_defined?(:@attn_implementation_internal)
        if instance_variable_get(:@attn_implementation_internal).nil?
          # `config.attn_implementation` should never be None, for backward compatibility.
          "eager"
        else
          @attn_implementation_internal
        end
      else
        "eager"
      end
    end

    def use_return_dict
      @return_dict
    end

    def to_s
      "#{self.class.name} #{to_json_string}"
    end

    def to_diff_dict
      config_dict = to_dict

      # get the default config dict
      default_config_dict = PretrainedConfig.new.to_dict

      serializable_config_dict = {}

      config_dict.each do |key, value|
        key = :_name_or_path if key == :name_or_path
        if !default_config_dict.include?(key) || value != default_config_dict[key] || key == :transformers_version
          serializable_config_dict[key] = value
        end
      end

      serializable_config_dict
    end

    def _dict
      instance_variables.to_h { |k| [k[1..].to_sym, instance_variable_get(k)] }
    end

    def to_dict
      output = Copy.deepcopy(_dict)
      output[:model_type] = self.class.model_type
      output.delete(:_auto_class)
      output.delete(:_commit_hash)
      output.delete(:_attn_implementation_internal)

      # Transformers version when serializing the model
      output[:transformers_version] = VERSION

      output
    end

    def to_json_string(use_diff: true)
      if use_diff == true
        config_dict = to_diff_dict
      else
        config_dict = to_dict
      end
      JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
    end

    class << self
      def from_pretrained(
        pretrained_model_name_or_path,
        cache_dir: nil,
        force_download: false,
        local_files_only: false,
        token: nil,
        revision: "main",
        **kwargs
      )
        config_dict, kwargs = get_config_dict(pretrained_model_name_or_path, **kwargs)

        from_dict(config_dict, **kwargs)
      end

      def from_dict(config_dict, **kwargs)
        return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }

        # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
        if kwargs.include?(:_commit_hash) && config_dict.include?(:_commit_hash)
          kwargs[:_commit_hash] = config_dict[:_commit_hash]
        end

        config = new(**config_dict)

        kwargs.each do |key, value|
          if config.respond_to?("#{key}=")
            config.public_send("#{key}=", value)
          end
        end

        Transformers.logger.info("Model config #{config}")
        if return_unused_kwargs
          [config, kwargs]
        else
          config
        end
      end

      def get_config_dict(pretrained_model_name_or_path, **kwargs)
        # Get config dict associated with the base config file
        config_dict, kwargs = _get_config_dict(pretrained_model_name_or_path, **kwargs)

        [config_dict, kwargs]
      end

      private

      def _get_config_dict(pretrained_model_name_or_path, **kwargs)
        cache_dir = kwargs.delete(:cache_dir)
        force_download = kwargs.delete(:force_download) { false }
        resume_download = kwargs.delete(:resume_download) { false }
        proxies = kwargs.delete(:proxies)
        token = kwargs.delete(:token)
        local_files_only = kwargs.delete(:local_files_only) { false }
        revision = kwargs.delete(:revision)
        _trust_remote_code = kwargs.delete(:trust_remote_code)
        subfolder = kwargs.delete(:subfolder) { "" }
        _from_pipeline = kwargs.delete(:_from_pipeline)
        from_auto_class = kwargs.delete(:_from_auto) { false }
        commit_hash = kwargs.delete(:_commit_hash)

        user_agent = {file_type: "config", from_auto_class: from_auto_class}

        is_local = Dir.exist?(pretrained_model_name_or_path)
        configuration_file = kwargs.delete(:_configuration_file) || CONFIG_NAME

        resolved_config_file = Utils::Hub.cached_file(
          pretrained_model_name_or_path,
          configuration_file,
          cache_dir: cache_dir,
          force_download: force_download,
          proxies: proxies,
          resume_download: resume_download,
          local_files_only: local_files_only,
          token: token,
          user_agent: user_agent,
          revision: revision,
          subfolder: subfolder,
          _commit_hash: commit_hash
        )
        commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)

        config_dict = _dict_from_json_file(resolved_config_file)
        config_dict[:_commit_hash] = commit_hash

        if is_local
          Transformers.logger.info("loading configuration file #{resolved_config_file}")
        else
          Transformers.logger.info("loading configuration file #{configuration_file} from cache at #{resolved_config_file}")
        end

        [config_dict, kwargs]
      end

      def _dict_from_json_file(json_file)
        JSON.load_file(json_file).transform_keys(&:to_sym)
      end
    end
  end
end
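For orientation, a minimal usage sketch of the config class above. It assumes a local checkpoint directory containing a config.json (the "./my-model" path is a placeholder, not part of the gem); the methods used (from_pretrained, num_labels, id2label, to_json_string) are defined in the file itself.

require "transformers-rb"

# Load a config from a local directory that contains config.json
# ("./my-model" is a placeholder path).
config = Transformers::PretrainedConfig.from_pretrained("./my-model")

config.num_labels   # label count derived from id2label (defaults to 2)
config.id2label     # id => label mapping, keys converted to integers

# Serialize back to JSON, keeping only values that differ from the defaults.
puts config.to_json_string(use_diff: true)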
data/lib/transformers/convert_slow_tokenizer.rb
@@ -0,0 +1,90 @@
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module ConvertSlowTokenizer
    class Converter
      def initialize(original_tokenizer)
        @original_tokenizer = original_tokenizer
      end

      def converted
        raise NotImplementedError
      end
    end

    class BertConverter < Converter
      def converted
        vocab = @original_tokenizer.vocab
        tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::WordPiece.new(vocab: vocab, unk_token: @original_tokenizer.unk_token.to_s))

        tokenize_chinese_chars = false
        strip_accents = false
        do_lower_case = false
        if @original_tokenizer.basic_tokenizer
          tokenize_chinese_chars = @original_tokenizer.basic_tokenizer.tokenize_chinese_chars
          strip_accents = @original_tokenizer.basic_tokenizer.strip_accents
          do_lower_case = @original_tokenizer.basic_tokenizer.do_lower_case
        end

        tokenizer.normalizer =
          Tokenizers::Normalizers::BertNormalizer.new(
            clean_text: true,
            handle_chinese_chars: tokenize_chinese_chars,
            strip_accents: strip_accents,
            lowercase: do_lower_case,
          )
        tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new

        cls = @original_tokenizer.cls_token.to_s
        sep = @original_tokenizer.sep_token.to_s
        cls_token_id = @original_tokenizer.cls_token_id
        sep_token_id = @original_tokenizer.sep_token_id

        tokenizer.post_processor =
          Tokenizers::Processors::TemplateProcessing.new(
            single: "#{cls}:0 $A:0 #{sep}:0",
            pair: "#{cls}:0 $A:0 #{sep}:0 $B:1 #{sep}:1",
            special_tokens: [
              [cls, cls_token_id],
              [sep, sep_token_id]
            ]
          )
        tokenizer.decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##")

        tokenizer
      end
    end

    SLOW_TO_FAST_CONVERTERS = {
      "BertTokenizer" => BertConverter,
      "DistilBertTokenizer" => BertConverter
    }

    def self.convert_slow_tokenizer(transformer_tokenizer)
      tokenizer_class_name = transformer_tokenizer.class.name.split("::").last

      if !SLOW_TO_FAST_CONVERTERS.include?(tokenizer_class_name)
        raise ArgumentError,
          "An instance of tokenizer class #{tokenizer_class_name} cannot be converted in a Fast tokenizer instance." +
          " No converter was found. Currently available slow->fast convertors:" +
          " #{SLOW_TO_FAST_CONVERTERS.keys}"
      end

      converter_class = SLOW_TO_FAST_CONVERTERS.fetch(tokenizer_class_name)

      converter_class.new(transformer_tokenizer).converted
    end
  end
end
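A hedged sketch of invoking the converter above. It assumes slow_tokenizer is an already-built instance whose class name (ignoring namespaces) is BertTokenizer or DistilBertTokenizer, since only those are registered in SLOW_TO_FAST_CONVERTERS.

# slow_tokenizer is assumed to exist; it must respond to vocab, unk_token,
# cls_token, sep_token, and the *_token_id readers used by BertConverter.
fast_tokenizer = Transformers::ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)

# The result is a Tokenizers::Tokenizer (from the tokenizers gem) configured with
# a BERT normalizer, pre-tokenizer, template post-processor, and WordPiece decoder.
fast_tokenizer.encode("hello world")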
data/lib/transformers/data/processors/squad.rb
@@ -0,0 +1,115 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class SquadExample
    attr_reader :question_text, :context_text

    def initialize(
      qas_id,
      question_text,
      context_text,
      answer_text,
      start_position_character,
      title,
      answers: [],
      is_impossible: false
    )
      @qas_id = qas_id
      @question_text = question_text
      @context_text = context_text
      @answer_text = answer_text
      @title = title
      @is_impossible = is_impossible
      @answers = answers

      @start_position, @end_position = 0, 0

      doc_tokens = []
      char_to_word_offset = []
      prev_is_whitespace = true

      # Split on whitespace so that different tokens may be attributed to their original position.
      @context_text.each_char do |c|
        if _is_whitespace(c)
          prev_is_whitespace = true
        else
          if prev_is_whitespace
            doc_tokens << c
          else
            doc_tokens[-1] += c
          end
          prev_is_whitespace = false
        end
        char_to_word_offset << (doc_tokens.length - 1)
      end

      @doc_tokens = doc_tokens
      @char_to_word_offset = char_to_word_offset

      # Start and end positions only has a value during evaluation.
      if !start_position_character.nil? && !is_impossible
        @start_position = char_to_word_offset[start_position_character]
        @end_position = char_to_word_offset[
          [start_position_character + answer_text.length - 1, char_to_word_offset.length - 1].min
        ]
      end
    end

    def _is_whitespace(c)
      c == " " || c == "\t" || c == "\r" || c == "\n" || c.ord == 0x202F
    end
  end

  class SquadFeatures
    def initialize(
      input_ids:,
      attention_mask:,
      token_type_ids:,
      cls_index:,
      p_mask:,
      example_index:,
      unique_id:,
      paragraph_len:,
      token_is_max_context:,
      tokens:,
      token_to_orig_map:,
      start_position:,
      end_position:,
      is_impossible:,
      qas_id: nil,
      encoding: nil
    )
      @input_ids = input_ids
      @attention_mask = attention_mask
      @token_type_ids = token_type_ids
      @cls_index = cls_index
      @p_mask = p_mask

      @example_index = example_index
      @unique_id = unique_id
      @paragraph_len = paragraph_len
      @token_is_max_context = token_is_max_context
      @tokens = tokens
      @token_to_orig_map = token_to_orig_map

      @start_position = start_position
      @end_position = end_position
      @is_impossible = is_impossible
      @qas_id = qas_id

      @encoding = encoding
    end
  end
end
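For illustration, a hedged construction of a SquadExample with made-up values; positional arguments follow the initializer above, and start_position_character is the 0-based character offset of the answer inside context_text.

example = Transformers::SquadExample.new(
  "qid-1",                                                   # qas_id (made up)
  "Who wrote the gem?",                                      # question_text
  "The transformers-rb gem was written by its maintainer.",  # context_text
  "its maintainer",                                          # answer_text
  39,                                                        # start_position_character ("its" starts at offset 39)
  "transformers-rb"                                          # title
)

example.question_text  # => "Who wrote the gem?"
example.context_text   # => the passage above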
data/lib/transformers/dynamic_module_utils.rb
@@ -0,0 +1,25 @@
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module DynamicModuleUtils
    # TODO improve
    def self.resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code)
      if trust_remote_code
        raise Error, "trust_remote_code not supported"
      end
      trust_remote_code
    end
  end
end
data/lib/transformers/feature_extraction_utils.rb
@@ -0,0 +1,110 @@
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class BatchFeature
    def initialize(data:, tensor_type:)
      @data = data
      convert_to_tensors(tensor_type: tensor_type)
    end

    def to_h
      @data
    end

    def [](item)
      @data[item]
    end

    def keys
      @data.keys
    end

    def values
      @data.values
    end

    def items
      @data
    end

    def _get_is_as_tensor_fns(tensor_type: nil)
      if tensor_type.nil?
        return [nil, nil]
      end

      as_tensor = lambda do |value|
        if value.is_a?(Array) && value.length > 0 && value[0].is_a?(Numo::NArray)
          value = Numo::NArray.cast(value)
        end
        Torch.tensor(value)
      end

      is_tensor = Torch.method(:tensor?)

      [is_tensor, as_tensor]
    end

    def convert_to_tensors(tensor_type: nil)
      if tensor_type.nil?
        return self
      end

      is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)

      # Do the tensor conversion in batch
      items.each do |key, value|
        begin
          if !is_tensor.(value)
            tensor = as_tensor.(value)

            @data[key] = tensor
          end
        rescue
          if key == :overflowing_values
            raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
          end
          raise ArgumentError,
            "Unable to create tensor, you should probably activate padding " +
            "with 'padding: true' to have batched tensors with the same length."
        end
      end

      self
    end

    def to(*args, **kwargs)
      new_data = {}
      device = kwargs[:device]
      # Check if the args are a device or a dtype
      if device.nil? && args.length > 0
        raise Todo
      end
      # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
      items.each do |k, v|
        # check if v is a floating point
        if Torch.floating_point?(v)
          # cast and send to device
          new_data[k] = v.to(*args, **kwargs)
        elsif !device.nil?
          new_data[k] = v.to(device)
        else
          new_data[k] = v
        end
      end
      @data = new_data
      self
    end
  end
end
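A sketch of using BatchFeature directly, assuming torch-rb is available and the data hash holds plain nested arrays; in practice the tokenizers and image processors in this gem build these objects for you, and the token ids below are made up.

features = Transformers::BatchFeature.new(
  data: {input_ids: [[101, 2023, 102]], attention_mask: [[1, 1, 1]]},  # made-up token ids
  tensor_type: "pt"
)

features.keys          # => [:input_ids, :attention_mask]
features[:input_ids]   # => a Torch::Tensor after convert_to_tensors

# Move tensors to a device; only floating point tensors are dtype-cast,
# mirroring the #to method above.
features.to(device: "cpu")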
data/lib/transformers/hf_hub/constants.rb
@@ -0,0 +1,71 @@
module Transformers
  module HfHub
    # Possible values for env variables

    ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"]
    ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES + ["AUTO"]

    def self._is_true(value)
      if value.nil?
        return false
      end
      ENV_VARS_TRUE_VALUES.include?(value.upcase)
    end

    def self._as_int(value)
      if value.nil?
        return nil
      end
      value.to_i
    end

    # Constants for file downloads

    DEFAULT_ETAG_TIMEOUT = 10

    # Git-related constants

    DEFAULT_REVISION = "main"

    ENDPOINT = ENV["HF_ENDPOINT"] || "https://huggingface.co"

    HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/%{repo_id}/resolve/%{revision}/%{filename}"
    HUGGINGFACE_HEADER_X_REPO_COMMIT = "x-repo-commit"
    HUGGINGFACE_HEADER_X_LINKED_ETAG = "x-linked-etag"
    HUGGINGFACE_HEADER_X_LINKED_SIZE = "x-linked-size"

    REPO_ID_SEPARATOR = "--"
    # ^ this substring is not allowed in repo_ids on hf.co
    # and is the canonical one we use for serialization of repo ids elsewhere.

    REPO_TYPE_DATASET = "dataset"
    REPO_TYPE_SPACE = "space"
    REPO_TYPE_MODEL = "model"
    REPO_TYPES = [nil, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]

    REPO_TYPES_URL_PREFIXES = {
      REPO_TYPE_DATASET => "datasets/",
      REPO_TYPE_SPACE => "spaces/",
    }

    # default cache
    DEFAULT_HOME = File.join(ENV.fetch("HOME"), ".cache")
    HF_HOME =
      File.expand_path(
        ENV.fetch(
          "HF_HOME",
          File.join(ENV.fetch("XDG_CACHE_HOME", DEFAULT_HOME), "huggingface")
        )
      )

    # New env variables
    HF_HUB_CACHE = ENV["HF_HUB_CACHE"] || File.join(HF_HOME, "hub")

    HF_HUB_OFFLINE = _is_true(ENV["HF_HUB_OFFLINE"] || ENV["TRANSFORMERS_OFFLINE"])

    # Disable sending the cached token by default is all HTTP requests to the Hub
    HF_HUB_DISABLE_IMPLICIT_TOKEN = _is_true(ENV["HF_HUB_DISABLE_IMPLICIT_TOKEN"])

    HF_HUB_ENABLE_HF_TRANSFER = _is_true(ENV["HF_HUB_ENABLE_HF_TRANSFER"])
  end
end
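As an illustration of how these constants compose (the repo id and filename below are arbitrary examples, not defaults from the gem), HUGGINGFACE_CO_URL_TEMPLATE is a format string that Kernel#format can expand.

url = format(
  Transformers::HfHub::HUGGINGFACE_CO_URL_TEMPLATE,
  repo_id: "bert-base-uncased",                     # example repo id
  revision: Transformers::HfHub::DEFAULT_REVISION,  # "main"
  filename: "config.json"
)
# => "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
#    (a custom HF_ENDPOINT changes the host via ENDPOINT)

# Downloaded files are cached under HF_HUB_CACHE, by default ~/.cache/huggingface/hub.
Transformers::HfHub::HF_HUB_CACHE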