transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/configuration_utils.rb
@@ -0,0 +1,285 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class PretrainedConfig
+    extend ClassAttribute
+
+    class_attribute :model_type, ""
+    class_attribute :attribute_map, {}
+
+    # TODO support setter
+    def method_missing(m, *args, **kwargs)
+      if self.class.attribute_map.include?(m)
+        instance_variable_get("@#{self.class.attribute_map[m]}")
+      else
+        super
+      end
+    end
+
+    # TODO support setter
+    def respond_to_missing?(m, include_private = true)
+      self.class.attribute_map.include?(m) || super
+    end
+
+    attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
+      :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
+      :problem_type, :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+
+    def initialize(**kwargs)
+      @return_dict = kwargs.delete(:return_dict) { true }
+      @output_hidden_states = kwargs.delete(:output_hidden_states) { false }
+      @output_attentions = kwargs.delete(:output_attentions) { false }
+      @pruned_heads = kwargs.delete(:pruned_heads) { {} }
+      @tie_word_embeddings = kwargs.delete(:tie_word_embeddings) { true }
+      @chunk_size_feed_forward = kwargs.delete(:chunk_size_feed_forward) { 0 }
+
+      # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
+      @is_encoder_decoder = kwargs.delete(:is_encoder_decoder) { false }
+      @is_decoder = kwargs.delete(:is_decoder) { false }
+      @cross_attention_hidden_size = kwargs.delete(:cross_attention_hidden_size)
+      @add_cross_attention = kwargs.delete(:add_cross_attention) { false }
+      @tie_encoder_decoder = kwargs.delete(:tie_encoder_decoder) { false }
+
+      # Fine-tuning task arguments
+      @architectures = kwargs.delete(:architectures)
+      @finetuning_task = kwargs.delete(:finetuning_task)
+      @id2label = kwargs.delete(:id2label)
+      @label2id = kwargs.delete(:label2id)
+      if !@label2id.nil? && !@label2id.is_a?(Hash)
+        raise ArgumentError, "Argument label2id should be a dictionary."
+      end
+      if !@id2label.nil?
+        if !@id2label.is_a?(Hash)
+          raise ArgumentError, "Argument id2label should be a dictionary."
+        end
+        num_labels = kwargs.delete(:num_labels)
+        if !num_labels.nil? && id2label.length != num_labels
+          raise Todo
+        end
+        @id2label = @id2label.transform_keys(&:to_i)
+        # Keys are always strings in JSON so convert ids to int here.
+      else
+        self.num_labels = kwargs.delete(:num_labels) { 2 }
+      end
+
+      # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+      @tokenizer_class = kwargs.delete(:tokenizer_class)
+      @prefix = kwargs.delete(:prefix)
+      @bos_token_id = kwargs.delete(:bos_token_id)
+      @pad_token_id = kwargs.delete(:pad_token_id)
+      @eos_token_id = kwargs.delete(:eos_token_id)
+      @sep_token_id = kwargs.delete(:sep_token_id)
+
+      # regression / multi-label classification
+      @problem_type = kwargs.delete(:problem_type)
+
+      # Name or path to the pretrained checkpoint
+      @name_or_path = kwargs.delete(:name_or_path).to_s
+      # Config hash
+      @commit_hash = kwargs.delete(:_commit_hash)
+
+      # TODO set kwargs
+      @gradient_checkpointing = kwargs[:gradient_checkpointing]
+      @output_past = kwargs[:output_past]
+      @tie_weights_ = kwargs[:tie_weights_]
+    end
+
+    def name_or_path
+      @name_or_path
+    end
+
+    def name_or_path=(value)
+      @name_or_path = value.to_s
+    end
+
+    def num_labels
+      @id2label.length
+    end
+
+    def num_labels=(num_labels)
+      if @id2label.nil? || @id2label.length != num_labels
+        @id2label = num_labels.times.to_h { |i| [i, "LABEL_#{i}"] }
+        @label2id = @id2label.invert
+      end
+    end
+
+    def _attn_implementation
+      # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+      if instance_variable_defined?(:@attn_implementation_internal)
+        if instance_variable_get(:@attn_implementation_internal).nil?
+          # `config.attn_implementation` should never be None, for backward compatibility.
+          "eager"
+        else
+          @attn_implementation_internal
+        end
+      else
+        "eager"
+      end
+    end
+
+    def use_return_dict
+      @return_dict
+    end
+
+    def to_s
+      "#{self.class.name} #{to_json_string}"
+    end
+
+    def to_diff_dict
+      config_dict = to_dict
+
+      # get the default config dict
+      default_config_dict = PretrainedConfig.new.to_dict
+
+      serializable_config_dict = {}
+
+      config_dict.each do |key, value|
+        key = :_name_or_path if key == :name_or_path
+        if !default_config_dict.include?(key) || value != default_config_dict[key] || key == :transformers_version
+          serializable_config_dict[key] = value
+        end
+      end
+
+      serializable_config_dict
+    end
+
+    def _dict
+      instance_variables.to_h { |k| [k[1..].to_sym, instance_variable_get(k)] }
+    end
+
+    def to_dict
+      output = Copy.deepcopy(_dict)
+      output[:model_type] = self.class.model_type
+      output.delete(:_auto_class)
+      output.delete(:_commit_hash)
+      output.delete(:_attn_implementation_internal)
+
+      # Transformers version when serializing the model
+      output[:transformers_version] = VERSION
+
+      output
+    end
+
+    def to_json_string(use_diff: true)
+      if use_diff == true
+        config_dict = to_diff_dict
+      else
+        config_dict = to_dict
+      end
+      JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
+    end
+
+    class << self
+      def from_pretrained(
+        pretrained_model_name_or_path,
+        cache_dir: nil,
+        force_download: false,
+        local_files_only: false,
+        token: nil,
+        revision: "main",
+        **kwargs
+      )
+        config_dict, kwargs = get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        from_dict(config_dict, **kwargs)
+      end
+
+      def from_dict(config_dict, **kwargs)
+        return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }
+
+        # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
+        if kwargs.include?(:_commit_hash) && config_dict.include?(:_commit_hash)
+          kwargs[:_commit_hash] = config_dict[:_commit_hash]
+        end
+
+        config = new(**config_dict)
+
+        kwargs.each do |key, value|
+          if config.respond_to?("#{key}=")
+            config.public_send("#{key}=", value)
+          end
+        end
+
+        Transformers.logger.info("Model config #{config}")
+        if return_unused_kwargs
+          [config, kwargs]
+        else
+          config
+        end
+      end
+
+      def get_config_dict(pretrained_model_name_or_path, **kwargs)
+        # Get config dict associated with the base config file
+        config_dict, kwargs = _get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        [config_dict, kwargs]
+      end
+
+      private
+
+      def _get_config_dict(pretrained_model_name_or_path, **kwargs)
+        cache_dir = kwargs.delete(:cache_dir)
+        force_download = kwargs.delete(:force_download) { false }
+        resume_download = kwargs.delete(:resume_download) { false }
+        proxies = kwargs.delete(:proxies)
+        token = kwargs.delete(:token)
+        local_files_only = kwargs.delete(:local_files_only) { false }
+        revision = kwargs.delete(:revision)
+        _trust_remote_code = kwargs.delete(:trust_remote_code)
+        subfolder = kwargs.delete(:subfolder) { "" }
+        _from_pipeline = kwargs.delete(:_from_pipeline)
+        from_auto_class = kwargs.delete(:_from_auto) { false }
+        commit_hash = kwargs.delete(:_commit_hash)
+
+        user_agent = {file_type: "config", from_auto_class: from_auto_class}
+
+        is_local = Dir.exist?(pretrained_model_name_or_path)
+        configuration_file = kwargs.delete(:_configuration_file) || CONFIG_NAME
+
+        resolved_config_file = Utils::Hub.cached_file(
+          pretrained_model_name_or_path,
+          configuration_file,
+          cache_dir: cache_dir,
+          force_download: force_download,
+          proxies: proxies,
+          resume_download: resume_download,
+          local_files_only: local_files_only,
+          token: token,
+          user_agent: user_agent,
+          revision: revision,
+          subfolder: subfolder,
+          _commit_hash: commit_hash
+        )
+        commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+
+        config_dict = _dict_from_json_file(resolved_config_file)
+        config_dict[:_commit_hash] = commit_hash
+
+        if is_local
+          Transformers.logger.info("loading configuration file #{resolved_config_file}")
+        else
+          Transformers.logger.info("loading configuration file #{configuration_file} from cache at #{resolved_config_file}")
+        end
+
+        [config_dict, kwargs]
+      end
+
+      def _dict_from_json_file(json_file)
+        JSON.load_file(json_file).transform_keys(&:to_sym)
+      end
+    end
+  end
+end
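Taken together, `from_pretrained` resolves `config.json` through `Utils::Hub.cached_file`, parses it into a symbol-keyed hash, and hands that hash to `from_dict`, which instantiates the config and applies any leftover keyword arguments via their setters. A minimal usage sketch, calling `from_dict` directly so nothing is downloaded (the hash contents are illustrative, not taken from this diff):

```ruby
require "transformers-rb"

# Build a config straight from a hash, the same way from_pretrained does
# after _dict_from_json_file has symbolized the keys of config.json.
config = Transformers::PretrainedConfig.from_dict({num_labels: 3})

config.num_labels  # => 3
config.id2label    # => {0 => "LABEL_0", 1 => "LABEL_1", 2 => "LABEL_2"}

# Serializes only values that differ from the defaults (use_diff: true).
puts config.to_json_string
```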
data/lib/transformers/convert_slow_tokenizer.rb
@@ -0,0 +1,90 @@
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module ConvertSlowTokenizer
+    class Converter
+      def initialize(original_tokenizer)
+        @original_tokenizer = original_tokenizer
+      end
+
+      def converted
+        raise NotImplementedError
+      end
+    end
+
+    class BertConverter < Converter
+      def converted
+        vocab = @original_tokenizer.vocab
+        tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::WordPiece.new(vocab: vocab, unk_token: @original_tokenizer.unk_token.to_s))
+
+        tokenize_chinese_chars = false
+        strip_accents = false
+        do_lower_case = false
+        if @original_tokenizer.basic_tokenizer
+          tokenize_chinese_chars = @original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+          strip_accents = @original_tokenizer.basic_tokenizer.strip_accents
+          do_lower_case = @original_tokenizer.basic_tokenizer.do_lower_case
+        end
+
+        tokenizer.normalizer =
+          Tokenizers::Normalizers::BertNormalizer.new(
+            clean_text: true,
+            handle_chinese_chars: tokenize_chinese_chars,
+            strip_accents: strip_accents,
+            lowercase: do_lower_case,
+          )
+        tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new
+
+        cls = @original_tokenizer.cls_token.to_s
+        sep = @original_tokenizer.sep_token.to_s
+        cls_token_id = @original_tokenizer.cls_token_id
+        sep_token_id = @original_tokenizer.sep_token_id
+
+        tokenizer.post_processor =
+          Tokenizers::Processors::TemplateProcessing.new(
+            single: "#{cls}:0 $A:0 #{sep}:0",
+            pair: "#{cls}:0 $A:0 #{sep}:0 $B:1 #{sep}:1",
+            special_tokens: [
+              [cls, cls_token_id],
+              [sep, sep_token_id]
+            ]
+          )
+        tokenizer.decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##")
+
+        tokenizer
+      end
+    end
+
+    SLOW_TO_FAST_CONVERTERS = {
+      "BertTokenizer" => BertConverter,
+      "DistilBertTokenizer" => BertConverter
+    }
+
+    def self.convert_slow_tokenizer(transformer_tokenizer)
+      tokenizer_class_name = transformer_tokenizer.class.name.split("::").last
+
+      if !SLOW_TO_FAST_CONVERTERS.include?(tokenizer_class_name)
+        raise ArgumentError,
+          "An instance of tokenizer class #{tokenizer_class_name} cannot be converted in a Fast tokenizer instance." +
+          " No converter was found. Currently available slow->fast convertors:" +
+          " #{SLOW_TO_FAST_CONVERTERS.keys}"
+      end
+
+      converter_class = SLOW_TO_FAST_CONVERTERS.fetch(tokenizer_class_name)
+
+      converter_class.new(transformer_tokenizer).converted
+    end
+  end
+end
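`convert_slow_tokenizer` is the bridge to the `tokenizers` gem: it rebuilds a slow, pure-Ruby BERT-style tokenizer as a fast `Tokenizers::Tokenizer` with a WordPiece model, BERT normalizer and pre-tokenizer, template post-processing, and a WordPiece decoder. A rough sketch, assuming a slow `Transformers::BertTokenizer` has already been loaded; `tokenization_bert.rb` is part of this gem but not shown in this excerpt, and the `from_pretrained` call mirrors the Python API rather than anything confirmed by this diff:

```ruby
# Hypothetical: load a slow BERT tokenizer shipped with the gem.
slow = Transformers::BertTokenizer.from_pretrained("bert-base-uncased")

fast = Transformers::ConvertSlowTokenizer.convert_slow_tokenizer(slow)
fast.class                         # => Tokenizers::Tokenizer
fast.encode("Hello world").tokens  # e.g. ["[CLS]", "hello", "world", "[SEP]"]
```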
data/lib/transformers/data/processors/squad.rb
@@ -0,0 +1,115 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class SquadExample
+    attr_reader :question_text, :context_text
+
+    def initialize(
+      qas_id,
+      question_text,
+      context_text,
+      answer_text,
+      start_position_character,
+      title,
+      answers: [],
+      is_impossible: false
+    )
+      @qas_id = qas_id
+      @question_text = question_text
+      @context_text = context_text
+      @answer_text = answer_text
+      @title = title
+      @is_impossible = is_impossible
+      @answers = answers
+
+      @start_position, @end_position = 0, 0
+
+      doc_tokens = []
+      char_to_word_offset = []
+      prev_is_whitespace = true
+
+      # Split on whitespace so that different tokens may be attributed to their original position.
+      @context_text.each_char do |c|
+        if _is_whitespace(c)
+          prev_is_whitespace = true
+        else
+          if prev_is_whitespace
+            doc_tokens << c
+          else
+            doc_tokens[-1] += c
+          end
+          prev_is_whitespace = false
+        end
+        char_to_word_offset << (doc_tokens.length - 1)
+      end
+
+      @doc_tokens = doc_tokens
+      @char_to_word_offset = char_to_word_offset
+
+      # Start and end positions only has a value during evaluation.
+      if !start_position_character.nil? && !is_impossible
+        @start_position = char_to_word_offset[start_position_character]
+        @end_position = char_to_word_offset[
+          [start_position_character + answer_text.length - 1, char_to_word_offset.length - 1].min
+        ]
+      end
+    end
+
+    def _is_whitespace(c)
+      c == " " || c == "\t" || c == "\r" || c == "\n" || c.ord == 0x202F
+    end
+  end
+
+  class SquadFeatures
+    def initialize(
+      input_ids:,
+      attention_mask:,
+      token_type_ids:,
+      cls_index:,
+      p_mask:,
+      example_index:,
+      unique_id:,
+      paragraph_len:,
+      token_is_max_context:,
+      tokens:,
+      token_to_orig_map:,
+      start_position:,
+      end_position:,
+      is_impossible:,
+      qas_id: nil,
+      encoding: nil
+    )
+      @input_ids = input_ids
+      @attention_mask = attention_mask
+      @token_type_ids = token_type_ids
+      @cls_index = cls_index
+      @p_mask = p_mask
+
+      @example_index = example_index
+      @unique_id = unique_id
+      @paragraph_len = paragraph_len
+      @token_is_max_context = token_is_max_context
+      @tokens = tokens
+      @token_to_orig_map = token_to_orig_map
+
+      @start_position = start_position
+      @end_position = end_position
+      @is_impossible = is_impossible
+      @qas_id = qas_id
+
+      @encoding = encoding
+    end
+  end
+end
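`SquadExample` does the SQuAD-style preprocessing up front: it splits the context on whitespace into `doc_tokens`, builds a character-to-word offset map, and, when a start character is given, maps the character-level answer span to word-level start/end positions. A small sketch with made-up data:

```ruby
example = Transformers::SquadExample.new(
  "qa-1",                           # qas_id
  "Where is the Eiffel Tower?",     # question_text
  "The Eiffel Tower is in Paris.",  # context_text
  "Paris",                          # answer_text
  23,                               # start_position_character ("P" of "Paris")
  "Eiffel Tower"                    # title
)

example.question_text  # => "Where is the Eiffel Tower?"
example.context_text   # => "The Eiffel Tower is in Paris."
```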
data/lib/transformers/dynamic_module_utils.rb
@@ -0,0 +1,25 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  module DynamicModuleUtils
+    # TODO improve
+    def self.resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code)
+      if trust_remote_code
+        raise Error, "trust_remote_code not supported"
+      end
+      trust_remote_code
+    end
+  end
+end
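`resolve_trust_remote_code` is currently a guard rather than a resolver: remote custom code is not supported, so a truthy `trust_remote_code` raises and any other value is passed straight through. For example:

```ruby
Transformers::DynamicModuleUtils.resolve_trust_remote_code(false, "bert-base-uncased", true, false)
# => false

Transformers::DynamicModuleUtils.resolve_trust_remote_code(true, "bert-base-uncased", false, true)
# raises "trust_remote_code not supported"
```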
data/lib/transformers/feature_extraction_utils.rb
@@ -0,0 +1,110 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class BatchFeature
+    def initialize(data:, tensor_type:)
+      @data = data
+      convert_to_tensors(tensor_type: tensor_type)
+    end
+
+    def to_h
+      @data
+    end
+
+    def [](item)
+      @data[item]
+    end
+
+    def keys
+      @data.keys
+    end
+
+    def values
+      @data.values
+    end
+
+    def items
+      @data
+    end
+
+    def _get_is_as_tensor_fns(tensor_type: nil)
+      if tensor_type.nil?
+        return [nil, nil]
+      end
+
+      as_tensor = lambda do |value|
+        if value.is_a?(Array) && value.length > 0 && value[0].is_a?(Numo::NArray)
+          value = Numo::NArray.cast(value)
+        end
+        Torch.tensor(value)
+      end
+
+      is_tensor = Torch.method(:tensor?)
+
+      [is_tensor, as_tensor]
+    end
+
+    def convert_to_tensors(tensor_type: nil)
+      if tensor_type.nil?
+        return self
+      end
+
+      is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)
+
+      # Do the tensor conversion in batch
+      items.each do |key, value|
+        begin
+          if !is_tensor.(value)
+            tensor = as_tensor.(value)
+
+            @data[key] = tensor
+          end
+        rescue
+          if key == :overflowing_values
+            raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
+          end
+          raise ArgumentError,
+            "Unable to create tensor, you should probably activate padding " +
+            "with 'padding: true' to have batched tensors with the same length."
+        end
+      end
+
+      self
+    end
+
+    def to(*args, **kwargs)
+      new_data = {}
+      device = kwargs[:device]
+      # Check if the args are a device or a dtype
+      if device.nil? && args.length > 0
+        raise Todo
+      end
+      # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+      items.each do |k, v|
+        # check if v is a floating point
+        if Torch.floating_point?(v)
+          # cast and send to device
+          new_data[k] = v.to(*args, **kwargs)
+        elsif !device.nil?
+          new_data[k] = v.to(device)
+        else
+          new_data[k] = v
+        end
+      end
+      @data = new_data
+      self
+    end
+  end
+end
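`BatchFeature` wraps a hash of model inputs; when `tensor_type` is non-nil it converts each value to a `Torch` tensor via torch-rb (casting Numo arrays first), and `#to` moves or casts only floating-point tensors so integer id tensors are left alone. A minimal sketch, assuming the torch-rb backend is installed; note the `"pt"` value only needs to be non-nil here:

```ruby
features = Transformers::BatchFeature.new(
  data: {input_ids: [[101, 2023, 102]], attention_mask: [[1, 1, 1]]},
  tensor_type: "pt"  # any non-nil value triggers tensor conversion
)

features.keys              # => [:input_ids, :attention_mask]
features[:input_ids].class # => Torch::Tensor
features.to(device: "cpu") # floating-point tensors are cast/moved; integer tensors pass through
```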
data/lib/transformers/hf_hub/constants.rb
@@ -0,0 +1,71 @@
+module Transformers
+  module HfHub
+    # Possible values for env variables
+
+    ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"]
+    ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES + ["AUTO"]
+
+    def self._is_true(value)
+      if value.nil?
+        return false
+      end
+      ENV_VARS_TRUE_VALUES.include?(value.upcase)
+    end
+
+    def self._as_int(value)
+      if value.nil?
+        return nil
+      end
+      value.to_i
+    end
+
+    # Constants for file downloads
+
+    DEFAULT_ETAG_TIMEOUT = 10
+
+    # Git-related constants
+
+    DEFAULT_REVISION = "main"
+
+    ENDPOINT = ENV["HF_ENDPOINT"] || "https://huggingface.co"
+
+    HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/%{repo_id}/resolve/%{revision}/%{filename}"
+    HUGGINGFACE_HEADER_X_REPO_COMMIT = "x-repo-commit"
+    HUGGINGFACE_HEADER_X_LINKED_ETAG = "x-linked-etag"
+    HUGGINGFACE_HEADER_X_LINKED_SIZE = "x-linked-size"
+
+    REPO_ID_SEPARATOR = "--"
+    # ^ this substring is not allowed in repo_ids on hf.co
+    # and is the canonical one we use for serialization of repo ids elsewhere.
+
+    REPO_TYPE_DATASET = "dataset"
+    REPO_TYPE_SPACE = "space"
+    REPO_TYPE_MODEL = "model"
+    REPO_TYPES = [nil, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
+
+    REPO_TYPES_URL_PREFIXES = {
+      REPO_TYPE_DATASET => "datasets/",
+      REPO_TYPE_SPACE => "spaces/",
+    }
+
+    # default cache
+    DEFAULT_HOME = File.join(ENV.fetch("HOME"), ".cache")
+    HF_HOME =
+      File.expand_path(
+        ENV.fetch(
+          "HF_HOME",
+          File.join(ENV.fetch("XDG_CACHE_HOME", DEFAULT_HOME), "huggingface")
+        )
+      )
+
+    # New env variables
+    HF_HUB_CACHE = ENV["HF_HUB_CACHE"] || File.join(HF_HOME, "hub")
+
+    HF_HUB_OFFLINE = _is_true(ENV["HF_HUB_OFFLINE"] || ENV["TRANSFORMERS_OFFLINE"])
+
+    # Disable sending the cached token by default is all HTTP requests to the Hub
+    HF_HUB_DISABLE_IMPLICIT_TOKEN = _is_true(ENV["HF_HUB_DISABLE_IMPLICIT_TOKEN"])
+
+    HF_HUB_ENABLE_HF_TRANSFER = _is_true(ENV["HF_HUB_ENABLE_HF_TRANSFER"])
+  end
+end
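`HUGGINGFACE_CO_URL_TEMPLATE` uses Ruby's `%{}` named references, so a resolve URL can be produced with `format`, while the cache constants fall back from `HF_HUB_CACHE`/`HF_HOME`/`XDG_CACHE_HOME` to `~/.cache/huggingface/hub`. For example, assuming `HF_ENDPOINT` is unset and the repo id is only illustrative:

```ruby
url = format(
  Transformers::HfHub::HUGGINGFACE_CO_URL_TEMPLATE,
  repo_id: "bert-base-uncased",
  revision: Transformers::HfHub::DEFAULT_REVISION,
  filename: "config.json"
)
url
# => "https://huggingface.co/bert-base-uncased/resolve/main/config.json"

Transformers::HfHub::HF_HUB_CACHE
# e.g. "/home/user/.cache/huggingface/hub" unless HF_HUB_CACHE or HF_HOME is set
```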