transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/configuration_utils.rb
@@ -0,0 +1,285 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class PretrainedConfig
+     extend ClassAttribute
+
+     class_attribute :model_type, ""
+     class_attribute :attribute_map, {}
+
+     # TODO support setter
+     def method_missing(m, *args, **kwargs)
+       if self.class.attribute_map.include?(m)
+         instance_variable_get("@#{self.class.attribute_map[m]}")
+       else
+         super
+       end
+     end
+
+     # TODO support setter
+     def respond_to_missing?(m, include_private = true)
+       self.class.attribute_map.include?(m) || super
+     end
+
+     attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
+       :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
+       :problem_type, :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+
+     def initialize(**kwargs)
+       @return_dict = kwargs.delete(:return_dict) { true }
+       @output_hidden_states = kwargs.delete(:output_hidden_states) { false }
+       @output_attentions = kwargs.delete(:output_attentions) { false }
+       @pruned_heads = kwargs.delete(:pruned_heads) { {} }
+       @tie_word_embeddings = kwargs.delete(:tie_word_embeddings) { true }
+       @chunk_size_feed_forward = kwargs.delete(:chunk_size_feed_forward) { 0 }
+
+       # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
+       @is_encoder_decoder = kwargs.delete(:is_encoder_decoder) { false }
+       @is_decoder = kwargs.delete(:is_decoder) { false }
+       @cross_attention_hidden_size = kwargs.delete(:cross_attention_hidden_size)
+       @add_cross_attention = kwargs.delete(:add_cross_attention) { false }
+       @tie_encoder_decoder = kwargs.delete(:tie_encoder_decoder) { false }
+
+       # Fine-tuning task arguments
+       @architectures = kwargs.delete(:architectures)
+       @finetuning_task = kwargs.delete(:finetuning_task)
+       @id2label = kwargs.delete(:id2label)
+       @label2id = kwargs.delete(:label2id)
+       if !@label2id.nil? && !@label2id.is_a?(Hash)
+         raise ArgumentError, "Argument label2id should be a dictionary."
+       end
+       if !@id2label.nil?
+         if !@id2label.is_a?(Hash)
+           raise ArgumentError, "Argument id2label should be a dictionary."
+         end
+         num_labels = kwargs.delete(:num_labels)
+         if !num_labels.nil? && id2label.length != num_labels
+           raise Todo
+         end
+         @id2label = @id2label.transform_keys(&:to_i)
+         # Keys are always strings in JSON so convert ids to int here.
+       else
+         self.num_labels = kwargs.delete(:num_labels) { 2 }
+       end
+
+       # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+       @tokenizer_class = kwargs.delete(:tokenizer_class)
+       @prefix = kwargs.delete(:prefix)
+       @bos_token_id = kwargs.delete(:bos_token_id)
+       @pad_token_id = kwargs.delete(:pad_token_id)
+       @eos_token_id = kwargs.delete(:eos_token_id)
+       @sep_token_id = kwargs.delete(:sep_token_id)
+
+       # regression / multi-label classification
+       @problem_type = kwargs.delete(:problem_type)
+
+       # Name or path to the pretrained checkpoint
+       @name_or_path = kwargs.delete(:name_or_path).to_s
+       # Config hash
+       @commit_hash = kwargs.delete(:_commit_hash)
+
+       # TODO set kwargs
+       @gradient_checkpointing = kwargs[:gradient_checkpointing]
+       @output_past = kwargs[:output_past]
+       @tie_weights_ = kwargs[:tie_weights_]
+     end
+
+     def name_or_path
+       @name_or_path
+     end
+
+     def name_or_path=(value)
+       @name_or_path = value.to_s
+     end
+
+     def num_labels
+       @id2label.length
+     end
+
+     def num_labels=(num_labels)
+       if @id2label.nil? || @id2label.length != num_labels
+         @id2label = num_labels.times.to_h { |i| [i, "LABEL_#{i}"] }
+         @label2id = @id2label.invert
+       end
+     end
+
+     def _attn_implementation
+       # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+       if instance_variable_defined?(:@attn_implementation_internal)
+         if instance_variable_get(:@attn_implementation_internal).nil?
+           # `config.attn_implementation` should never be None, for backward compatibility.
+           "eager"
+         else
+           @attn_implementation_internal
+         end
+       else
+         "eager"
+       end
+     end
+
+     def use_return_dict
+       @return_dict
+     end
+
+     def to_s
+       "#{self.class.name} #{to_json_string}"
+     end
+
+     def to_diff_dict
+       config_dict = to_dict
+
+       # get the default config dict
+       default_config_dict = PretrainedConfig.new.to_dict
+
+       serializable_config_dict = {}
+
+       config_dict.each do |key, value|
+         key = :_name_or_path if key == :name_or_path
+         if !default_config_dict.include?(key) || value != default_config_dict[key] || key == :transformers_version
+           serializable_config_dict[key] = value
+         end
+       end
+
+       serializable_config_dict
+     end
+
+     def _dict
+       instance_variables.to_h { |k| [k[1..].to_sym, instance_variable_get(k)] }
+     end
+
+     def to_dict
+       output = Copy.deepcopy(_dict)
+       output[:model_type] = self.class.model_type
+       output.delete(:_auto_class)
+       output.delete(:_commit_hash)
+       output.delete(:_attn_implementation_internal)
+
+       # Transformers version when serializing the model
+       output[:transformers_version] = VERSION
+
+       output
+     end
+
+     def to_json_string(use_diff: true)
+       if use_diff == true
+         config_dict = to_diff_dict
+       else
+         config_dict = to_dict
+       end
+       JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
+     end
+
+     class << self
+       def from_pretrained(
+         pretrained_model_name_or_path,
+         cache_dir: nil,
+         force_download: false,
+         local_files_only: false,
+         token: nil,
+         revision: "main",
+         **kwargs
+       )
+         config_dict, kwargs = get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+         from_dict(config_dict, **kwargs)
+       end
+
+       def from_dict(config_dict, **kwargs)
+         return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }
+
+         # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
+         if kwargs.include?(:_commit_hash) && config_dict.include?(:_commit_hash)
+           kwargs[:_commit_hash] = config_dict[:_commit_hash]
+         end
+
+         config = new(**config_dict)
+
+         kwargs.each do |key, value|
+           if config.respond_to?("#{key}=")
+             config.public_send("#{key}=", value)
+           end
+         end
+
+         Transformers.logger.info("Model config #{config}")
+         if return_unused_kwargs
+           [config, kwargs]
+         else
+           config
+         end
+       end
+
+       def get_config_dict(pretrained_model_name_or_path, **kwargs)
+         # Get config dict associated with the base config file
+         config_dict, kwargs = _get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+         [config_dict, kwargs]
+       end
+
+       private
+
+       def _get_config_dict(pretrained_model_name_or_path, **kwargs)
+         cache_dir = kwargs.delete(:cache_dir)
+         force_download = kwargs.delete(:force_download) { false }
+         resume_download = kwargs.delete(:resume_download) { false }
+         proxies = kwargs.delete(:proxies)
+         token = kwargs.delete(:token)
+         local_files_only = kwargs.delete(:local_files_only) { false }
+         revision = kwargs.delete(:revision)
+         _trust_remote_code = kwargs.delete(:trust_remote_code)
+         subfolder = kwargs.delete(:subfolder) { "" }
+         _from_pipeline = kwargs.delete(:_from_pipeline)
+         from_auto_class = kwargs.delete(:_from_auto) { false }
+         commit_hash = kwargs.delete(:_commit_hash)
+
+         user_agent = {file_type: "config", from_auto_class: from_auto_class}
+
+         is_local = Dir.exist?(pretrained_model_name_or_path)
+         configuration_file = kwargs.delete(:_configuration_file) || CONFIG_NAME
+
+         resolved_config_file = Utils::Hub.cached_file(
+           pretrained_model_name_or_path,
+           configuration_file,
+           cache_dir: cache_dir,
+           force_download: force_download,
+           proxies: proxies,
+           resume_download: resume_download,
+           local_files_only: local_files_only,
+           token: token,
+           user_agent: user_agent,
+           revision: revision,
+           subfolder: subfolder,
+           _commit_hash: commit_hash
+         )
+         commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+
+         config_dict = _dict_from_json_file(resolved_config_file)
+         config_dict[:_commit_hash] = commit_hash
+
+         if is_local
+           Transformers.logger.info("loading configuration file #{resolved_config_file}")
+         else
+           Transformers.logger.info("loading configuration file #{configuration_file} from cache at #{resolved_config_file}")
+         end
+
+         [config_dict, kwargs]
+       end
+
+       def _dict_from_json_file(json_file)
+         JSON.load_file(json_file).transform_keys(&:to_sym)
+       end
+     end
+   end
+ end
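A minimal usage sketch for the config class above, not part of the diff; it assumes the gem is installed and loaded with require "transformers-rb":

    require "transformers-rb"

    # Build a config from keyword args; num_labels generates id2label/label2id defaults.
    config = Transformers::PretrainedConfig.new(name_or_path: "demo", num_labels: 3)
    config.num_labels   # => 3
    config.id2label     # => {0 => "LABEL_0", 1 => "LABEL_1", 2 => "LABEL_2"}

    # Serialize only the values that differ from a default PretrainedConfig (use_diff: true).
    puts config.to_json_string(use_diff: true)

    # Loading from the Hub goes through get_config_dict and Utils::Hub.cached_file shown above,
    # e.g. (network access required, model-specific subclasses assumed from the files listed earlier):
    # config = Transformers::PretrainedConfig.from_pretrained("bert-base-uncased")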
data/lib/transformers/convert_slow_tokenizer.rb
@@ -0,0 +1,90 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module ConvertSlowTokenizer
+     class Converter
+       def initialize(original_tokenizer)
+         @original_tokenizer = original_tokenizer
+       end
+
+       def converted
+         raise NotImplementedError
+       end
+     end
+
+     class BertConverter < Converter
+       def converted
+         vocab = @original_tokenizer.vocab
+         tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::WordPiece.new(vocab: vocab, unk_token: @original_tokenizer.unk_token.to_s))
+
+         tokenize_chinese_chars = false
+         strip_accents = false
+         do_lower_case = false
+         if @original_tokenizer.basic_tokenizer
+           tokenize_chinese_chars = @original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+           strip_accents = @original_tokenizer.basic_tokenizer.strip_accents
+           do_lower_case = @original_tokenizer.basic_tokenizer.do_lower_case
+         end
+
+         tokenizer.normalizer =
+           Tokenizers::Normalizers::BertNormalizer.new(
+             clean_text: true,
+             handle_chinese_chars: tokenize_chinese_chars,
+             strip_accents: strip_accents,
+             lowercase: do_lower_case,
+           )
+         tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new
+
+         cls = @original_tokenizer.cls_token.to_s
+         sep = @original_tokenizer.sep_token.to_s
+         cls_token_id = @original_tokenizer.cls_token_id
+         sep_token_id = @original_tokenizer.sep_token_id
+
+         tokenizer.post_processor =
+           Tokenizers::Processors::TemplateProcessing.new(
+             single: "#{cls}:0 $A:0 #{sep}:0",
+             pair: "#{cls}:0 $A:0 #{sep}:0 $B:1 #{sep}:1",
+             special_tokens: [
+               [cls, cls_token_id],
+               [sep, sep_token_id]
+             ]
+           )
+         tokenizer.decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##")
+
+         tokenizer
+       end
+     end
+
+     SLOW_TO_FAST_CONVERTERS = {
+       "BertTokenizer" => BertConverter,
+       "DistilBertTokenizer" => BertConverter
+     }
+
+     def self.convert_slow_tokenizer(transformer_tokenizer)
+       tokenizer_class_name = transformer_tokenizer.class.name.split("::").last
+
+       if !SLOW_TO_FAST_CONVERTERS.include?(tokenizer_class_name)
+         raise ArgumentError,
+           "An instance of tokenizer class #{tokenizer_class_name} cannot be converted in a Fast tokenizer instance." +
+           " No converter was found. Currently available slow->fast convertors:" +
+           " #{SLOW_TO_FAST_CONVERTERS.keys}"
+       end
+
+       converter_class = SLOW_TO_FAST_CONVERTERS.fetch(tokenizer_class_name)
+
+       converter_class.new(transformer_tokenizer).converted
+     end
+   end
+ end
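A hedged sketch of how this converter is intended to be called, not part of the diff; slow_tokenizer stands in for a "slow" tokenizer instance (e.g. a BertTokenizer from the files listed earlier), and the .encode(...).ids call assumes the tokenizers gem API:

    # Convert a slow Ruby tokenizer into a fast Tokenizers::Tokenizer.
    fast = Transformers::ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
    fast.encode("hello world").ids

    # Unsupported classes raise ArgumentError listing SLOW_TO_FAST_CONVERTERS.keys:
    begin
      Transformers::ConvertSlowTokenizer.convert_slow_tokenizer(Object.new)
    rescue ArgumentError => e
      puts e.message
    end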
data/lib/transformers/data/processors/squad.rb
@@ -0,0 +1,115 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class SquadExample
+     attr_reader :question_text, :context_text
+
+     def initialize(
+       qas_id,
+       question_text,
+       context_text,
+       answer_text,
+       start_position_character,
+       title,
+       answers: [],
+       is_impossible: false
+     )
+       @qas_id = qas_id
+       @question_text = question_text
+       @context_text = context_text
+       @answer_text = answer_text
+       @title = title
+       @is_impossible = is_impossible
+       @answers = answers
+
+       @start_position, @end_position = 0, 0
+
+       doc_tokens = []
+       char_to_word_offset = []
+       prev_is_whitespace = true
+
+       # Split on whitespace so that different tokens may be attributed to their original position.
+       @context_text.each_char do |c|
+         if _is_whitespace(c)
+           prev_is_whitespace = true
+         else
+           if prev_is_whitespace
+             doc_tokens << c
+           else
+             doc_tokens[-1] += c
+           end
+           prev_is_whitespace = false
+         end
+         char_to_word_offset << (doc_tokens.length - 1)
+       end
+
+       @doc_tokens = doc_tokens
+       @char_to_word_offset = char_to_word_offset
+
+       # Start and end positions only has a value during evaluation.
+       if !start_position_character.nil? && !is_impossible
+         @start_position = char_to_word_offset[start_position_character]
+         @end_position = char_to_word_offset[
+           [start_position_character + answer_text.length - 1, char_to_word_offset.length - 1].min
+         ]
+       end
+     end
+
+     def _is_whitespace(c)
+       c == " " || c == "\t" || c == "\r" || c == "\n" || c.ord == 0x202F
+     end
+   end
+
+   class SquadFeatures
+     def initialize(
+       input_ids:,
+       attention_mask:,
+       token_type_ids:,
+       cls_index:,
+       p_mask:,
+       example_index:,
+       unique_id:,
+       paragraph_len:,
+       token_is_max_context:,
+       tokens:,
+       token_to_orig_map:,
+       start_position:,
+       end_position:,
+       is_impossible:,
+       qas_id: nil,
+       encoding: nil
+     )
+       @input_ids = input_ids
+       @attention_mask = attention_mask
+       @token_type_ids = token_type_ids
+       @cls_index = cls_index
+       @p_mask = p_mask
+
+       @example_index = example_index
+       @unique_id = unique_id
+       @paragraph_len = paragraph_len
+       @token_is_max_context = token_is_max_context
+       @tokens = tokens
+       @token_to_orig_map = token_to_orig_map
+
+       @start_position = start_position
+       @end_position = end_position
+       @is_impossible = is_impossible
+       @qas_id = qas_id
+
+       @encoding = encoding
+     end
+   end
+ end
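A small illustrative example of the whitespace splitting done in SquadExample#initialize, not part of the diff; the argument values are made up:

    example = Transformers::SquadExample.new(
      "q1",                    # qas_id
      "Who wrote it?",         # question_text
      "Jane wrote the book.",  # context_text
      "Jane",                  # answer_text
      0,                       # start_position_character
      "demo"                   # title
    )
    example.context_text  # => "Jane wrote the book."
    # Internally, char_to_word_offset maps character 0 ("J") to word index 0, so both
    # @start_position and @end_position resolve to word 0 for the answer "Jane".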
data/lib/transformers/dynamic_module_utils.rb
@@ -0,0 +1,25 @@
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module DynamicModuleUtils
+     # TODO improve
+     def self.resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code)
+       if trust_remote_code
+         raise Error, "trust_remote_code not supported"
+       end
+       trust_remote_code
+     end
+   end
+ end
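The stub above only supports the non-remote path; a quick illustrative check (argument values are made up):

    # Returns the flag unchanged when remote code is not requested;
    # any truthy value raises because remote code execution is not supported yet.
    Transformers::DynamicModuleUtils.resolve_trust_remote_code(false, "bert-base-uncased", true, false)
    # => false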
data/lib/transformers/feature_extraction_utils.rb
@@ -0,0 +1,110 @@
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class BatchFeature
+     def initialize(data:, tensor_type:)
+       @data = data
+       convert_to_tensors(tensor_type: tensor_type)
+     end
+
+     def to_h
+       @data
+     end
+
+     def [](item)
+       @data[item]
+     end
+
+     def keys
+       @data.keys
+     end
+
+     def values
+       @data.values
+     end
+
+     def items
+       @data
+     end
+
+     def _get_is_as_tensor_fns(tensor_type: nil)
+       if tensor_type.nil?
+         return [nil, nil]
+       end
+
+       as_tensor = lambda do |value|
+         if value.is_a?(Array) && value.length > 0 && value[0].is_a?(Numo::NArray)
+           value = Numo::NArray.cast(value)
+         end
+         Torch.tensor(value)
+       end
+
+       is_tensor = Torch.method(:tensor?)
+
+       [is_tensor, as_tensor]
+     end
+
+     def convert_to_tensors(tensor_type: nil)
+       if tensor_type.nil?
+         return self
+       end
+
+       is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)
+
+       # Do the tensor conversion in batch
+       items.each do |key, value|
+         begin
+           if !is_tensor.(value)
+             tensor = as_tensor.(value)
+
+             @data[key] = tensor
+           end
+         rescue
+           if key == :overflowing_values
+             raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
+           end
+           raise ArgumentError,
+             "Unable to create tensor, you should probably activate padding " +
+             "with 'padding: true' to have batched tensors with the same length."
+         end
+       end
+
+       self
+     end
+
+     def to(*args, **kwargs)
+       new_data = {}
+       device = kwargs[:device]
+       # Check if the args are a device or a dtype
+       if device.nil? && args.length > 0
+         raise Todo
+       end
+       # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+       items.each do |k, v|
+         # check if v is a floating point
+         if Torch.floating_point?(v)
+           # cast and send to device
+           new_data[k] = v.to(*args, **kwargs)
+         elsif !device.nil?
+           new_data[k] = v.to(device)
+         else
+           new_data[k] = v
+         end
+       end
+       @data = new_data
+       self
+     end
+   end
+ end
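A hedged usage sketch for BatchFeature, not part of the diff; it assumes torch-rb is installed, since the as_tensor lambda calls Torch.tensor, and the example token ids are made up:

    # Wrap already-padded token ids and convert them to Torch tensors in one step.
    batch = Transformers::BatchFeature.new(
      data: {input_ids: [[101, 2023, 102]], attention_mask: [[1, 1, 1]]},
      tensor_type: "pt"
    )
    batch[:input_ids]        # => Torch::Tensor of shape [1, 3]
    batch.keys               # => [:input_ids, :attention_mask]
    batch.to(device: "cpu")  # non-floating tensors are moved via v.to(device)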
data/lib/transformers/hf_hub/constants.rb
@@ -0,0 +1,71 @@
+ module Transformers
+   module HfHub
+     # Possible values for env variables
+
+     ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"]
+     ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES + ["AUTO"]
+
+     def self._is_true(value)
+       if value.nil?
+         return false
+       end
+       ENV_VARS_TRUE_VALUES.include?(value.upcase)
+     end
+
+     def self._as_int(value)
+       if value.nil?
+         return nil
+       end
+       value.to_i
+     end
+
+     # Constants for file downloads
+
+     DEFAULT_ETAG_TIMEOUT = 10
+
+     # Git-related constants
+
+     DEFAULT_REVISION = "main"
+
+     ENDPOINT = ENV["HF_ENDPOINT"] || "https://huggingface.co"
+
+     HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/%{repo_id}/resolve/%{revision}/%{filename}"
+     HUGGINGFACE_HEADER_X_REPO_COMMIT = "x-repo-commit"
+     HUGGINGFACE_HEADER_X_LINKED_ETAG = "x-linked-etag"
+     HUGGINGFACE_HEADER_X_LINKED_SIZE = "x-linked-size"
+
+     REPO_ID_SEPARATOR = "--"
+     # ^ this substring is not allowed in repo_ids on hf.co
+     # and is the canonical one we use for serialization of repo ids elsewhere.
+
+     REPO_TYPE_DATASET = "dataset"
+     REPO_TYPE_SPACE = "space"
+     REPO_TYPE_MODEL = "model"
+     REPO_TYPES = [nil, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
+
+     REPO_TYPES_URL_PREFIXES = {
+       REPO_TYPE_DATASET => "datasets/",
+       REPO_TYPE_SPACE => "spaces/",
+     }
+
+     # default cache
+     DEFAULT_HOME = File.join(ENV.fetch("HOME"), ".cache")
+     HF_HOME =
+       File.expand_path(
+         ENV.fetch(
+           "HF_HOME",
+           File.join(ENV.fetch("XDG_CACHE_HOME", DEFAULT_HOME), "huggingface")
+         )
+       )
+
+     # New env variables
+     HF_HUB_CACHE = ENV["HF_HUB_CACHE"] || File.join(HF_HOME, "hub")
+
+     HF_HUB_OFFLINE = _is_true(ENV["HF_HUB_OFFLINE"] || ENV["TRANSFORMERS_OFFLINE"])
+
+     # Disable sending the cached token by default is all HTTP requests to the Hub
+     HF_HUB_DISABLE_IMPLICIT_TOKEN = _is_true(ENV["HF_HUB_DISABLE_IMPLICIT_TOKEN"])
+
+     HF_HUB_ENABLE_HF_TRANSFER = _is_true(ENV["HF_HUB_ENABLE_HF_TRANSFER"])
+   end
+ end
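For reference, a small sketch of how these constants compose a download URL and the default cache path, not part of the diff; the repo id and filename are made-up example values:

    template = Transformers::HfHub::HUGGINGFACE_CO_URL_TEMPLATE
    url = format(template, repo_id: "bert-base-uncased", revision: "main", filename: "config.json")
    # => "https://huggingface.co/bert-base-uncased/resolve/main/config.json" (when HF_ENDPOINT is unset)

    Transformers::HfHub::HF_HUB_CACHE
    # => "#{ENV["HOME"]}/.cache/huggingface/hub" when none of HF_HUB_CACHE, HF_HOME, or XDG_CACHE_HOME is set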