transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/configuration_utils.rb
@@ -0,0 +1,285 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class PretrainedConfig
+     extend ClassAttribute
+
+     class_attribute :model_type, ""
+     class_attribute :attribute_map, {}
+
+     # TODO support setter
+     def method_missing(m, *args, **kwargs)
+       if self.class.attribute_map.include?(m)
+         instance_variable_get("@#{self.class.attribute_map[m]}")
+       else
+         super
+       end
+     end
+
+     # TODO support setter
+     def respond_to_missing?(m, include_private = true)
+       self.class.attribute_map.include?(m) || super
+     end
+
+     attr_reader :output_hidden_states, :output_attentions, :pruned_heads, :tie_word_embeddings, :tokenizer_class,
+       :chunk_size_feed_forward, :pad_token_id, :is_decoder, :add_cross_attention,
+       :problem_type, :id2label, :architectures, :is_encoder_decoder, :tie_encoder_decoder, :_commit_hash
+
+     def initialize(**kwargs)
+       @return_dict = kwargs.delete(:return_dict) { true }
+       @output_hidden_states = kwargs.delete(:output_hidden_states) { false }
+       @output_attentions = kwargs.delete(:output_attentions) { false }
+       @pruned_heads = kwargs.delete(:pruned_heads) { {} }
+       @tie_word_embeddings = kwargs.delete(:tie_word_embeddings) { true }
+       @chunk_size_feed_forward = kwargs.delete(:chunk_size_feed_forward) { 0 }
+
+       # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
+       @is_encoder_decoder = kwargs.delete(:is_encoder_decoder) { false }
+       @is_decoder = kwargs.delete(:is_decoder) { false }
+       @cross_attention_hidden_size = kwargs.delete(:cross_attention_hidden_size)
+       @add_cross_attention = kwargs.delete(:add_cross_attention) { false }
+       @tie_encoder_decoder = kwargs.delete(:tie_encoder_decoder) { false }
+
+       # Fine-tuning task arguments
+       @architectures = kwargs.delete(:architectures)
+       @finetuning_task = kwargs.delete(:finetuning_task)
+       @id2label = kwargs.delete(:id2label)
+       @label2id = kwargs.delete(:label2id)
+       if !@label2id.nil? && !@label2id.is_a?(Hash)
+         raise ArgumentError, "Argument label2id should be a dictionary."
+       end
+       if !@id2label.nil?
+         if !@id2label.is_a?(Hash)
+           raise ArgumentError, "Argument id2label should be a dictionary."
+         end
+         num_labels = kwargs.delete(:num_labels)
+         if !num_labels.nil? && id2label.length != num_labels
+           raise Todo
+         end
+         # Keys are always strings in JSON so convert ids to int here.
+         @id2label = @id2label.transform_keys(&:to_i)
+       else
+         self.num_labels = kwargs.delete(:num_labels) { 2 }
+       end
+
+       # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+       @tokenizer_class = kwargs.delete(:tokenizer_class)
+       @prefix = kwargs.delete(:prefix)
+       @bos_token_id = kwargs.delete(:bos_token_id)
+       @pad_token_id = kwargs.delete(:pad_token_id)
+       @eos_token_id = kwargs.delete(:eos_token_id)
+       @sep_token_id = kwargs.delete(:sep_token_id)
+
+       # regression / multi-label classification
+       @problem_type = kwargs.delete(:problem_type)
+
+       # Name or path to the pretrained checkpoint
+       @name_or_path = kwargs.delete(:name_or_path).to_s
+       # Config hash
+       @commit_hash = kwargs.delete(:_commit_hash)
+
+       # TODO set kwargs
+       @gradient_checkpointing = kwargs[:gradient_checkpointing]
+       @output_past = kwargs[:output_past]
+       @tie_weights_ = kwargs[:tie_weights_]
+     end
+
+     def name_or_path
+       @name_or_path
+     end
+
+     def name_or_path=(value)
+       @name_or_path = value.to_s
+     end
+
+     def num_labels
+       @id2label.length
+     end
+
+     def num_labels=(num_labels)
+       if @id2label.nil? || @id2label.length != num_labels
+         @id2label = num_labels.times.to_h { |i| [i, "LABEL_#{i}"] }
+         @label2id = @id2label.invert
+       end
+     end
+
+     def _attn_implementation
+       # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
+       if instance_variable_defined?(:@attn_implementation_internal)
+         if instance_variable_get(:@attn_implementation_internal).nil?
+           # `config.attn_implementation` should never be nil, for backward compatibility.
+           "eager"
+         else
+           @attn_implementation_internal
+         end
+       else
+         "eager"
+       end
+     end
+
+     def use_return_dict
+       @return_dict
+     end
+
+     def to_s
+       "#{self.class.name} #{to_json_string}"
+     end
+
+     def to_diff_dict
+       config_dict = to_dict
+
+       # get the default config dict
+       default_config_dict = PretrainedConfig.new.to_dict
+
+       serializable_config_dict = {}
+
+       config_dict.each do |key, value|
+         key = :_name_or_path if key == :name_or_path
+         if !default_config_dict.include?(key) || value != default_config_dict[key] || key == :transformers_version
+           serializable_config_dict[key] = value
+         end
+       end
+
+       serializable_config_dict
+     end
+
+     def _dict
+       instance_variables.to_h { |k| [k[1..].to_sym, instance_variable_get(k)] }
+     end
+
+     def to_dict
+       output = Copy.deepcopy(_dict)
+       output[:model_type] = self.class.model_type
+       output.delete(:_auto_class)
+       output.delete(:_commit_hash)
+       output.delete(:_attn_implementation_internal)
+
+       # Transformers version when serializing the model
+       output[:transformers_version] = VERSION
+
+       output
+     end
+
+     def to_json_string(use_diff: true)
+       if use_diff == true
+         config_dict = to_diff_dict
+       else
+         config_dict = to_dict
+       end
+       JSON.pretty_generate(config_dict.sort_by { |k, _| k }.to_h) + "\n"
+     end
+
+     class << self
+       def from_pretrained(
+         pretrained_model_name_or_path,
+         cache_dir: nil,
+         force_download: false,
+         local_files_only: false,
+         token: nil,
+         revision: "main",
+         **kwargs
+       )
+         config_dict, kwargs = get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+         from_dict(config_dict, **kwargs)
+       end
+
+       def from_dict(config_dict, **kwargs)
+         return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }
+
+         # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
+         if kwargs.include?(:_commit_hash) && config_dict.include?(:_commit_hash)
+           kwargs[:_commit_hash] = config_dict[:_commit_hash]
+         end
+
+         config = new(**config_dict)
+
+         kwargs.each do |key, value|
+           if config.respond_to?("#{key}=")
+             config.public_send("#{key}=", value)
+           end
+         end
+
+         Transformers.logger.info("Model config #{config}")
+         if return_unused_kwargs
+           [config, kwargs]
+         else
+           config
+         end
+       end
+
+       def get_config_dict(pretrained_model_name_or_path, **kwargs)
+         # Get config dict associated with the base config file
+         config_dict, kwargs = _get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+         [config_dict, kwargs]
+       end
+
+       private
+
+       def _get_config_dict(pretrained_model_name_or_path, **kwargs)
+         cache_dir = kwargs.delete(:cache_dir)
+         force_download = kwargs.delete(:force_download) { false }
+         resume_download = kwargs.delete(:resume_download) { false }
+         proxies = kwargs.delete(:proxies)
+         token = kwargs.delete(:token)
+         local_files_only = kwargs.delete(:local_files_only) { false }
+         revision = kwargs.delete(:revision)
+         _trust_remote_code = kwargs.delete(:trust_remote_code)
+         subfolder = kwargs.delete(:subfolder) { "" }
+         _from_pipeline = kwargs.delete(:_from_pipeline)
+         from_auto_class = kwargs.delete(:_from_auto) { false }
+         commit_hash = kwargs.delete(:_commit_hash)
+
+         user_agent = {file_type: "config", from_auto_class: from_auto_class}
+
+         is_local = Dir.exist?(pretrained_model_name_or_path)
+         configuration_file = kwargs.delete(:_configuration_file) || CONFIG_NAME
+
+         resolved_config_file = Utils::Hub.cached_file(
+           pretrained_model_name_or_path,
+           configuration_file,
+           cache_dir: cache_dir,
+           force_download: force_download,
+           proxies: proxies,
+           resume_download: resume_download,
+           local_files_only: local_files_only,
+           token: token,
+           user_agent: user_agent,
+           revision: revision,
+           subfolder: subfolder,
+           _commit_hash: commit_hash
+         )
+         commit_hash = Utils::Hub.extract_commit_hash(resolved_config_file, commit_hash)
+
+         config_dict = _dict_from_json_file(resolved_config_file)
+         config_dict[:_commit_hash] = commit_hash
+
+         if is_local
+           Transformers.logger.info("loading configuration file #{resolved_config_file}")
+         else
+           Transformers.logger.info("loading configuration file #{configuration_file} from cache at #{resolved_config_file}")
+         end
+
+         [config_dict, kwargs]
+       end
+
+       def _dict_from_json_file(json_file)
+         JSON.load_file(json_file).transform_keys(&:to_sym)
+       end
+     end
+   end
+ end
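
For orientation, a minimal usage sketch of the PretrainedConfig API above (the model id, network access, and an installed torch backend are assumptions, not part of this diff):

    require "transformers-rb"

    # Build a config locally; num_labels: 3 populates id2label/label2id with LABEL_0..LABEL_2.
    config = Transformers::PretrainedConfig.new(num_labels: 3)
    config.id2label                       # => {0=>"LABEL_0", 1=>"LABEL_1", 2=>"LABEL_2"}
    puts config.to_json_string(use_diff: false)

    # Resolve a config.json from the Hugging Face Hub into the local cache (needs network access).
    remote = Transformers::PretrainedConfig.from_pretrained("bert-base-uncased")
    remote.use_return_dict                # => true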
data/lib/transformers/convert_slow_tokenizer.rb
@@ -0,0 +1,90 @@
+ # Copyright 2018 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module ConvertSlowTokenizer
+     class Converter
+       def initialize(original_tokenizer)
+         @original_tokenizer = original_tokenizer
+       end
+
+       def converted
+         raise NotImplementedError
+       end
+     end
+
+     class BertConverter < Converter
+       def converted
+         vocab = @original_tokenizer.vocab
+         tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::WordPiece.new(vocab: vocab, unk_token: @original_tokenizer.unk_token.to_s))
+
+         tokenize_chinese_chars = false
+         strip_accents = false
+         do_lower_case = false
+         if @original_tokenizer.basic_tokenizer
+           tokenize_chinese_chars = @original_tokenizer.basic_tokenizer.tokenize_chinese_chars
+           strip_accents = @original_tokenizer.basic_tokenizer.strip_accents
+           do_lower_case = @original_tokenizer.basic_tokenizer.do_lower_case
+         end
+
+         tokenizer.normalizer =
+           Tokenizers::Normalizers::BertNormalizer.new(
+             clean_text: true,
+             handle_chinese_chars: tokenize_chinese_chars,
+             strip_accents: strip_accents,
+             lowercase: do_lower_case,
+           )
+         tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::BertPreTokenizer.new
+
+         cls = @original_tokenizer.cls_token.to_s
+         sep = @original_tokenizer.sep_token.to_s
+         cls_token_id = @original_tokenizer.cls_token_id
+         sep_token_id = @original_tokenizer.sep_token_id
+
+         tokenizer.post_processor =
+           Tokenizers::Processors::TemplateProcessing.new(
+             single: "#{cls}:0 $A:0 #{sep}:0",
+             pair: "#{cls}:0 $A:0 #{sep}:0 $B:1 #{sep}:1",
+             special_tokens: [
+               [cls, cls_token_id],
+               [sep, sep_token_id]
+             ]
+           )
+         tokenizer.decoder = Tokenizers::Decoders::WordPiece.new(prefix: "##")
+
+         tokenizer
+       end
+     end
+
+     SLOW_TO_FAST_CONVERTERS = {
+       "BertTokenizer" => BertConverter,
+       "DistilBertTokenizer" => BertConverter
+     }
+
+     def self.convert_slow_tokenizer(transformer_tokenizer)
+       tokenizer_class_name = transformer_tokenizer.class.name.split("::").last
+
+       if !SLOW_TO_FAST_CONVERTERS.include?(tokenizer_class_name)
+         raise ArgumentError,
+           "An instance of tokenizer class #{tokenizer_class_name} cannot be converted to a Fast tokenizer instance." +
+           " No converter was found. Currently available slow->fast converters:" +
+           " #{SLOW_TO_FAST_CONVERTERS.keys}"
+       end
+
+       converter_class = SLOW_TO_FAST_CONVERTERS.fetch(tokenizer_class_name)
+
+       converter_class.new(transformer_tokenizer).converted
+     end
+   end
+ end
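
A hedged usage sketch of the converter above; `slow` stands for an already-loaded Transformers::BertTokenizer (building one requires a vocab.txt), and the encode/tokens calls are the usual tokenizers-gem API:

    # `slow` is assumed to be a Transformers::BertTokenizer instance loaded elsewhere.
    fast = Transformers::ConvertSlowTokenizer.convert_slow_tokenizer(slow)
    fast.class               # => Tokenizers::Tokenizer (Rust-backed)
    encoding = fast.encode("Hello world")
    encoding.tokens          # e.g. ["[CLS]", "hello", "world", "[SEP]"], depending on the vocab and casing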
data/lib/transformers/data/processors/squad.rb
@@ -0,0 +1,115 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class SquadExample
+     attr_reader :question_text, :context_text
+
+     def initialize(
+       qas_id,
+       question_text,
+       context_text,
+       answer_text,
+       start_position_character,
+       title,
+       answers: [],
+       is_impossible: false
+     )
+       @qas_id = qas_id
+       @question_text = question_text
+       @context_text = context_text
+       @answer_text = answer_text
+       @title = title
+       @is_impossible = is_impossible
+       @answers = answers
+
+       @start_position, @end_position = 0, 0
+
+       doc_tokens = []
+       char_to_word_offset = []
+       prev_is_whitespace = true
+
+       # Split on whitespace so that different tokens may be attributed to their original position.
+       @context_text.each_char do |c|
+         if _is_whitespace(c)
+           prev_is_whitespace = true
+         else
+           if prev_is_whitespace
+             doc_tokens << c
+           else
+             doc_tokens[-1] += c
+           end
+           prev_is_whitespace = false
+         end
+         char_to_word_offset << (doc_tokens.length - 1)
+       end
+
+       @doc_tokens = doc_tokens
+       @char_to_word_offset = char_to_word_offset
+
+       # Start and end positions only have a value during evaluation.
+       if !start_position_character.nil? && !is_impossible
+         @start_position = char_to_word_offset[start_position_character]
+         @end_position = char_to_word_offset[
+           [start_position_character + answer_text.length - 1, char_to_word_offset.length - 1].min
+         ]
+       end
+     end
+
+     def _is_whitespace(c)
+       c == " " || c == "\t" || c == "\r" || c == "\n" || c.ord == 0x202F
+     end
+   end
+
+   class SquadFeatures
+     def initialize(
+       input_ids:,
+       attention_mask:,
+       token_type_ids:,
+       cls_index:,
+       p_mask:,
+       example_index:,
+       unique_id:,
+       paragraph_len:,
+       token_is_max_context:,
+       tokens:,
+       token_to_orig_map:,
+       start_position:,
+       end_position:,
+       is_impossible:,
+       qas_id: nil,
+       encoding: nil
+     )
+       @input_ids = input_ids
+       @attention_mask = attention_mask
+       @token_type_ids = token_type_ids
+       @cls_index = cls_index
+       @p_mask = p_mask
+
+       @example_index = example_index
+       @unique_id = unique_id
+       @paragraph_len = paragraph_len
+       @token_is_max_context = token_is_max_context
+       @tokens = tokens
+       @token_to_orig_map = token_to_orig_map
+
+       @start_position = start_position
+       @end_position = end_position
+       @is_impossible = is_impossible
+       @qas_id = qas_id
+
+       @encoding = encoding
+     end
+   end
+ end
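
To make the whitespace mapping in SquadExample#initialize concrete, an illustrative sketch (the strings and ids are made up):

    example = Transformers::SquadExample.new(
      "qid-1",             # qas_id
      "What is my name?",  # question_text
      "My name is Ruby",   # context_text
      "Ruby",              # answer_text
      11,                  # start_position_character ("Ruby" starts at character 11)
      "demo"               # title
    )
    # doc_tokens become ["My", "name", "is", "Ruby"]; char_to_word_offset[11..14] are all 3,
    # so @start_position and @end_position both point at whitespace-token index 3.
    example.context_text    # => "My name is Ruby"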
data/lib/transformers/dynamic_module_utils.rb
@@ -0,0 +1,25 @@
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module DynamicModuleUtils
+     # TODO improve
+     def self.resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code)
+       if trust_remote_code
+         raise Error, "trust_remote_code not supported"
+       end
+       trust_remote_code
+     end
+   end
+ end
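
Behavior sketch for the stub above: remote code is simply refused in this port (the model ids are illustrative):

    Transformers::DynamicModuleUtils.resolve_trust_remote_code(false, "bert-base-uncased", true, false)
    # => false
    Transformers::DynamicModuleUtils.resolve_trust_remote_code(true, "some/model", false, true)
    # raises an error: "trust_remote_code not supported"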
data/lib/transformers/feature_extraction_utils.rb
@@ -0,0 +1,110 @@
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class BatchFeature
+     def initialize(data:, tensor_type:)
+       @data = data
+       convert_to_tensors(tensor_type: tensor_type)
+     end
+
+     def to_h
+       @data
+     end
+
+     def [](item)
+       @data[item]
+     end
+
+     def keys
+       @data.keys
+     end
+
+     def values
+       @data.values
+     end
+
+     def items
+       @data
+     end
+
+     def _get_is_as_tensor_fns(tensor_type: nil)
+       if tensor_type.nil?
+         return [nil, nil]
+       end
+
+       as_tensor = lambda do |value|
+         if value.is_a?(Array) && value.length > 0 && value[0].is_a?(Numo::NArray)
+           value = Numo::NArray.cast(value)
+         end
+         Torch.tensor(value)
+       end
+
+       is_tensor = Torch.method(:tensor?)
+
+       [is_tensor, as_tensor]
+     end
+
+     def convert_to_tensors(tensor_type: nil)
+       if tensor_type.nil?
+         return self
+       end
+
+       is_tensor, as_tensor = _get_is_as_tensor_fns(tensor_type: tensor_type)
+
+       # Do the tensor conversion in batch
+       items.each do |key, value|
+         begin
+           if !is_tensor.(value)
+             tensor = as_tensor.(value)
+
+             @data[key] = tensor
+           end
+         rescue
+           if key == :overflowing_values
+             raise ArgumentError, "Unable to create tensor returning overflowing values of different lengths."
+           end
+           raise ArgumentError,
+             "Unable to create tensor, you should probably activate padding " +
+             "with 'padding: true' to have batched tensors with the same length."
+         end
+       end
+
+       self
+     end
+
+     def to(*args, **kwargs)
+       new_data = {}
+       device = kwargs[:device]
+       # Check if the args are a device or a dtype
+       if device.nil? && args.length > 0
+         raise Todo
+       end
+       # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+       items.each do |k, v|
+         # check if v is a floating point
+         if Torch.floating_point?(v)
+           # cast and send to device
+           new_data[k] = v.to(*args, **kwargs)
+         elsif !device.nil?
+           new_data[k] = v.to(device)
+         else
+           new_data[k] = v
+         end
+       end
+       @data = new_data
+       self
+     end
+   end
+ end
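
A minimal sketch of BatchFeature in use (assumes torch-rb is installed; the token ids are illustrative, and the "pt" tensor_type only needs to be non-nil to trigger conversion):

    batch = Transformers::BatchFeature.new(
      data: {input_ids: [[101, 7592, 102]], attention_mask: [[1, 1, 1]]},
      tensor_type: "pt"
    )
    batch.keys               # => [:input_ids, :attention_mask]
    batch[:input_ids]        # => a Torch::Tensor of shape [1, 3]
    batch.to(device: "cpu")  # returns self; non-floating tensors are moved with Tensor#to(device)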
data/lib/transformers/hf_hub/constants.rb
@@ -0,0 +1,71 @@
+ module Transformers
+   module HfHub
+     # Possible values for env variables
+
+     ENV_VARS_TRUE_VALUES = ["1", "ON", "YES", "TRUE"]
+     ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES + ["AUTO"]
+
+     def self._is_true(value)
+       if value.nil?
+         return false
+       end
+       ENV_VARS_TRUE_VALUES.include?(value.upcase)
+     end
+
+     def self._as_int(value)
+       if value.nil?
+         return nil
+       end
+       value.to_i
+     end
+
+     # Constants for file downloads
+
+     DEFAULT_ETAG_TIMEOUT = 10
+
+     # Git-related constants
+
+     DEFAULT_REVISION = "main"
+
+     ENDPOINT = ENV["HF_ENDPOINT"] || "https://huggingface.co"
+
+     HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/%{repo_id}/resolve/%{revision}/%{filename}"
+     HUGGINGFACE_HEADER_X_REPO_COMMIT = "x-repo-commit"
+     HUGGINGFACE_HEADER_X_LINKED_ETAG = "x-linked-etag"
+     HUGGINGFACE_HEADER_X_LINKED_SIZE = "x-linked-size"
+
+     REPO_ID_SEPARATOR = "--"
+     # ^ this substring is not allowed in repo_ids on hf.co
+     # and is the canonical one we use for serialization of repo ids elsewhere.
+
+     REPO_TYPE_DATASET = "dataset"
+     REPO_TYPE_SPACE = "space"
+     REPO_TYPE_MODEL = "model"
+     REPO_TYPES = [nil, REPO_TYPE_MODEL, REPO_TYPE_DATASET, REPO_TYPE_SPACE]
+
+     REPO_TYPES_URL_PREFIXES = {
+       REPO_TYPE_DATASET => "datasets/",
+       REPO_TYPE_SPACE => "spaces/",
+     }
+
+     # default cache
+     DEFAULT_HOME = File.join(ENV.fetch("HOME"), ".cache")
+     HF_HOME =
+       File.expand_path(
+         ENV.fetch(
+           "HF_HOME",
+           File.join(ENV.fetch("XDG_CACHE_HOME", DEFAULT_HOME), "huggingface")
+         )
+       )
+
+     # New env variables
+     HF_HUB_CACHE = ENV["HF_HUB_CACHE"] || File.join(HF_HOME, "hub")
+
+     HF_HUB_OFFLINE = _is_true(ENV["HF_HUB_OFFLINE"] || ENV["TRANSFORMERS_OFFLINE"])
+
+     # Disable sending the cached token by default in all HTTP requests to the Hub
+     HF_HUB_DISABLE_IMPLICIT_TOKEN = _is_true(ENV["HF_HUB_DISABLE_IMPLICIT_TOKEN"])
+
+     HF_HUB_ENABLE_HF_TRANSFER = _is_true(ENV["HF_HUB_ENABLE_HF_TRANSFER"])
+   end
+ end
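
The %{...} placeholders in HUGGINGFACE_CO_URL_TEMPLATE are filled with Kernel#format; a small sketch (the repo id is illustrative):

    url = format(
      Transformers::HfHub::HUGGINGFACE_CO_URL_TEMPLATE,
      repo_id: "bert-base-uncased",
      revision: "main",
      filename: "config.json"
    )
    # => "https://huggingface.co/bert-base-uncased/resolve/main/config.json" (with the default ENDPOINT)

    Transformers::HfHub::HF_HUB_CACHE
    # => "#{Dir.home}/.cache/huggingface/hub" unless HF_HOME, HF_HUB_CACHE, or XDG_CACHE_HOME is set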