transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/tokenization_utils_fast.rb
@@ -0,0 +1,386 @@
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class PreTrainedTokenizerFast < PreTrainedTokenizerBase
    def initialize(*args, **kwargs)
      tokenizer_object = kwargs.delete(:tokenizer_object)
      slow_tokenizer = kwargs.delete(:__slow_tokenizer)
      fast_tokenizer_file = kwargs.delete(:tokenizer_file)
      from_slow = kwargs.delete(:from_slow) { false }
      _added_tokens_decoder = kwargs.delete(:added_tokens_decoder)

      if !tokenizer_object.nil?
        fast_tokenizer = Copy.deepcopy(tokenizer_object)
      elsif !fast_tokenizer_file.nil? && !from_slow
        # We have a serialization from tokenizers which lets us directly build the backend
        fast_tokenizer = Tokenizers::Tokenizer.from_file(fast_tokenizer_file)
      elsif !slow_tokenizer.nil?
        # We need to convert a slow tokenizer to build the backend
        fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
      elsif !@slow_tokenizer_class.nil?
        # We need to create and convert a slow tokenizer to build the backend
        slow_tokenizer = @slow_tokenizer_class.new(*args, **kwargs)
        fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
      else
        raise ArgumentError, <<~MSG
          Couldn't instantiate the backend tokenizer from one of:
          (1) a `tokenizers` library serialization file,
          (2) a slow tokenizer instance to convert or
          (3) an equivalent slow tokenizer class to instantiate and convert.
          You need to have sentencepiece installed to convert a slow tokenizer to a fast one.
        MSG
      end

      @tokenizer = fast_tokenizer

      if !slow_tokenizer.nil?
        kwargs.merge!(slow_tokenizer.init_kwargs)
      end

      @decode_use_source_tokenizer = false

      _truncation = @tokenizer.truncation

      if !_truncation.nil?
        _truncation = _truncation.transform_keys(&:to_sym)
        @tokenizer.enable_truncation(_truncation[:max_length], **_truncation.except(:max_length))
        kwargs[:max_length] ||= _truncation[:max_length]
        kwargs[:truncation_side] ||= _truncation[:direction]
        kwargs[:stride] ||= _truncation[:stride]
        kwargs[:truncation_strategy] ||= _truncation[:strategy]
      else
        @tokenizer.no_truncation
      end

      _padding = @tokenizer.padding
      if !_padding.nil?
        _padding = _padding.transform_keys(&:to_sym)
        @tokenizer.enable_padding(**_padding)
        kwargs[:pad_token] ||= _padding[:pad_token]
        kwargs[:pad_token_type_id] ||= _padding[:pad_token_type_id]
        kwargs[:padding_side] ||= _padding[:direction]
        kwargs[:max_length] ||= _padding[:length]
        kwargs[:pad_to_multiple_of] ||= _padding[:pad_to_multiple_of]
      end

      # We call this after having initialized the backend tokenizer because we update it.
      super(**kwargs)
    end

    def is_fast
      true
    end

    def get_vocab
      @tokenizer.vocab(with_added_tokens: true)
    end

    def vocab
      get_vocab
    end

    def convert_tokens_to_ids(tokens)
      if tokens.nil?
        return nil
      end

      if tokens.is_a?(String)
        return _convert_token_to_id_with_added_voc(tokens)
      end

      ids = []
      tokens.each do |token|
        ids << _convert_token_to_id_with_added_voc(token)
      end
      ids
    end

    def _convert_token_to_id_with_added_voc(token)
      index = @tokenizer.token_to_id(token)
      if index.nil?
        return unk_token_id
      end
      index
    end

    def convert_ids_to_tokens(ids, skip_special_tokens: false)
      if ids.is_a?(Integer)
        return @tokenizer.id_to_token(ids)
      end
      tokens = []
      ids.each do |index|
        index = index.to_i
        if skip_special_tokens && @all_special_ids.include?(index)
          next
        end
        tokens << @tokenizer.id_to_token(index)
      end
      tokens
    end

    private

    def set_truncation_and_padding(
      padding_strategy:,
      truncation_strategy:,
      max_length:,
      stride:,
      pad_to_multiple_of:
    )
      _truncation = @tokenizer.truncation
      _padding = @tokenizer.padding
      # Set truncation and padding on the backend tokenizer
      if truncation_strategy == TruncationStrategy::DO_NOT_TRUNCATE
        if !_truncation.nil?
          @tokenizer.no_truncation
        end
      else
        target = {
          max_length: max_length,
          stride: stride,
          strategy: truncation_strategy,
          direction: @truncation_side
        }

        # _truncation might contain more keys than the target `transformers`
        # supports. Use only the target keys to trigger `enable_truncation`.
        # This should enable this code to work on various `tokenizers`
        # targets.
        if _truncation.nil?
          current = nil
        else
          current = target.to_h { |k, _| [k, _truncation[k]] }
        end

        if current != target
          @tokenizer.enable_truncation(target.delete(:max_length), **target)
        end
      end

      if padding_strategy == PaddingStrategy::DO_NOT_PAD
        if !_padding.nil?
          @tokenizer.no_padding
        end
      else
        length = padding_strategy == PaddingStrategy::MAX_LENGTH ? max_length : nil
        target = {
          length: length,
          direction: @padding_side,
          pad_id: @pad_token_id,
          pad_token: @pad_token,
          pad_type_id: @pad_token_type_id,
          pad_to_multiple_of: pad_to_multiple_of
        }
        if _padding != target
          @tokenizer.enable_padding(**target)
        end
      end
    end

    def _batch_encode_plus(
      batch_text_or_text_pairs,
      add_special_tokens: true,
      padding_strategy: PaddingStrategy::DO_NOT_PAD,
      truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true
    )
      if !batch_text_or_text_pairs.is_a?(Array)
        raise TypeError, "batch_text_or_text_pairs has to be an array (got #{batch_text_or_text_pairs.class.name})"
      end

      # Set the truncation and padding strategy and restore the initial configuration
      set_truncation_and_padding(
        padding_strategy: padding_strategy,
        truncation_strategy: truncation_strategy,
        max_length: max_length,
        stride: stride,
        pad_to_multiple_of: pad_to_multiple_of
      )

      encodings =
        @tokenizer.encode_batch(
          batch_text_or_text_pairs,
          add_special_tokens: add_special_tokens,
          is_pretokenized: is_split_into_words,
        )

      # Convert encoding to dict
      # `Tokens` has type: Tuple[
      #     List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
      #     List[EncodingFast]
      # ]
      # with nested dimensions corresponding to batch, overflows, sequence length
      tokens_and_encodings =
        encodings.map do |encoding|
          _convert_encoding(
            encoding: encoding,
            return_token_type_ids: return_token_type_ids,
            return_attention_mask: return_attention_mask,
            return_overflowing_tokens: return_overflowing_tokens,
            return_special_tokens_mask: return_special_tokens_mask,
            return_offsets_mapping: return_offsets_mapping,
            return_length: return_length,
            verbose: verbose
          )
        end

      # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
      # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
      # (we say ~ because the number of overflows varies with the example in the batch)
      #
      # To match each overflowing sample with the original sample in the batch
      # we add an overflow_to_sample_mapping array (see below)
      sanitized_tokens = {}
      tokens_and_encodings[0][0].each_key do |key|
        stack = tokens_and_encodings.map { |item, _| item[key][0] }
        sanitized_tokens[key] = stack
      end
      sanitized_encodings = tokens_and_encodings.map { |_, item| item[0] }

      # If returning overflowing tokens, we need to return a mapping
      # from the batch idx to the original sample
      if return_overflowing_tokens
        overflow_to_sample_mapping = []
        tokens_and_encodings.each_with_index do |(toks, _), i|
          overflow_to_sample_mapping += [i] * toks["input_ids"].length
        end
        sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
      end

      sanitized_tokens["input_ids"].each do |input_ids|
        _eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
      end

      BatchEncoding.new(data: sanitized_tokens, encoding: sanitized_encodings, tensor_type: return_tensors)
    end

    def _convert_encoding(
      encoding:,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true
    )
      if return_token_type_ids.nil?
        return_token_type_ids = self.class.model_input_names.include?("token_type_ids")
      end
      if return_attention_mask.nil?
        return_attention_mask = self.class.model_input_names.include?("attention_mask")
      end

      if return_overflowing_tokens && !encoding.overflowing.nil?
        encodings = [encoding] + encoding.overflowing
      else
        encodings = [encoding]
      end

      encoding_dict = Hash.new { |h, k| h[k] = [] }
      encodings.each do |e|
        encoding_dict["input_ids"] << e.ids

        if return_token_type_ids
          encoding_dict["token_type_ids"] << e.type_ids
        end
        if return_attention_mask
          encoding_dict["attention_mask"] << e.attention_mask
        end
        if return_special_tokens_mask
          encoding_dict["special_tokens_mask"] << e.special_tokens_mask
        end
        if return_offsets_mapping
          encoding_dict["offset_mapping"] << e.offsets
        end
        if return_length
          encoding_dict["length"] << e.ids.length
        end
      end

      [encoding_dict, encodings]
    end

    def _encode_plus(
      text:,
      text_pair: nil,
      add_special_tokens: true,
      padding_strategy: PaddingStrategy::DO_NOT_PAD,
      truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
      max_length: nil,
      stride: 0,
      is_split_into_words: false,
      pad_to_multiple_of: nil,
      return_tensors: nil,
      return_token_type_ids: nil,
      return_attention_mask: nil,
      return_overflowing_tokens: false,
      return_special_tokens_mask: false,
      return_offsets_mapping: false,
      return_length: false,
      verbose: true,
      **kwargs
    )
      batched_input = text_pair ? [[text, text_pair]] : [text]
      batched_output =
        _batch_encode_plus(
          batched_input,
          is_split_into_words: is_split_into_words,
          add_special_tokens: add_special_tokens,
          padding_strategy: padding_strategy,
          truncation_strategy: truncation_strategy,
          max_length: max_length,
          stride: stride,
          pad_to_multiple_of: pad_to_multiple_of,
          return_tensors: return_tensors,
          return_token_type_ids: return_token_type_ids,
          return_attention_mask: return_attention_mask,
          return_overflowing_tokens: return_overflowing_tokens,
          return_special_tokens_mask: return_special_tokens_mask,
          return_offsets_mapping: return_offsets_mapping,
          return_length: return_length,
          verbose: verbose,
          **kwargs
        )

      # If return_tensors is nil, we can remove the leading batch axis.
      # Overflowing tokens are returned as a batch of output so we keep them in this case.
      if return_tensors.nil? && !return_overflowing_tokens
        batched_output =
          BatchEncoding.new(
            data: batched_output.items.to_h { |key, value|
              [key, value.length > 0 && value[0].is_a?(Array) ? value[0] : value]
            },
            encoding: batched_output.encodings
          )
      end

      _eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)

      batched_output
    end
  end
end
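For orientation, here is a minimal usage sketch of the fast tokenizer surface defined above. It is illustrative only: it assumes `Transformers::AutoTokenizer.from_pretrained` (from `tokenization_auto.rb`) resolves "bert-base-uncased" to a `PreTrainedTokenizerFast`, and that the base class exposes a `call`-style encode as in the Python API; the model name and example strings are not from the gem's docs.

```ruby
require "transformers-rb"

# Hypothetical setup: a fast BERT tokenizer loaded via the auto class.
tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")

enc = tokenizer.("Ruby meets BERT")   # encode via the base class
enc["input_ids"]                      # token ids including special tokens

# Methods defined in tokenization_utils_fast.rb above:
ids = tokenizer.convert_tokens_to_ids(["ruby", "meets"])  # unknown tokens map to unk_token_id
tokenizer.convert_ids_to_tokens(ids)                      # => ["ruby", "meets"]
tokenizer.get_vocab.size                                  # vocabulary size, with added tokens
```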
data/lib/transformers/torch_utils.rb
@@ -0,0 +1,25 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module TorchUtils
    def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
      if chunk_size > 0
        raise Todo
      end

      forward_fn.(*input_tensors)
    end
  end
end
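As written, only the no-chunking path is implemented: with `chunk_size == 0` the helper simply invokes the callable on the inputs, and any positive chunk size raises `Todo`. A small sketch of that path, assuming torch.rb (`Torch`) is available; the lambda and tensor are made up for illustration:

```ruby
require "torch"

# Hypothetical feed-forward callable; doubles its input tensor.
feed_forward = ->(hidden_states) { hidden_states * 2.0 }

hidden_states = Torch.ones(2, 4, 8)

# chunk_size = 0 means "do not chunk": the callable is applied directly.
out = Transformers::TorchUtils.apply_chunking_to_forward(feed_forward, 0, 1, hidden_states)
out.shape  # => [2, 4, 8]
```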
data/lib/transformers/utils/_init.rb
@@ -0,0 +1,31 @@
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  WEIGHTS_NAME = "pytorch_model.bin"
  WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
  TF2_WEIGHTS_NAME = "tf_model.h5"
  TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
  TF_WEIGHTS_NAME = "model.ckpt"
  FLAX_WEIGHTS_NAME = "flax_model.msgpack"
  FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
  SAFE_WEIGHTS_NAME = "model.safetensors"
  SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
  CONFIG_NAME = "config.json"
  FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
  IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
  PROCESSOR_NAME = "processor_config.json"
  GENERATION_CONFIG_NAME = "generation_config.json"
  MODEL_CARD_NAME = "modelcard.json"
end
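These constants name the standard checkpoint and config files inside a Hugging Face model repo. A hedged sketch of how they are typically used to pick a weights file from a downloaded checkpoint directory (the gem's actual resolution logic lives in `modeling_utils.rb`; the directory path here is hypothetical):

```ruby
checkpoint_dir = "/path/to/checkpoint"  # hypothetical local directory

# Prefer safetensors weights when present, otherwise fall back to the pickle format.
weights_file =
  if File.exist?(File.join(checkpoint_dir, Transformers::SAFE_WEIGHTS_NAME))
    File.join(checkpoint_dir, Transformers::SAFE_WEIGHTS_NAME)  # model.safetensors
  else
    File.join(checkpoint_dir, Transformers::WEIGHTS_NAME)       # pytorch_model.bin
  end

config_file = File.join(checkpoint_dir, Transformers::CONFIG_NAME)  # config.json
```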
data/lib/transformers/utils/generic.rb
@@ -0,0 +1,107 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  class ModelOutput
    def self.attributes
      @attributes ||= []
    end

    def self.attribute(attribute)
      attributes << attribute.to_sym

      define_method(attribute) do
        self[attribute]
      end
    end

    def initialize(**kwargs)
      @data = kwargs
    end

    def [](k)
      if k.is_a?(String) || k.is_a?(Symbol)
        @data[k.to_sym]
      else
        to_tuple[k]
      end
    end

    def to_tuple
      self.class.attributes.map { |k| @data[k] }.compact
    end
  end

  class ExplicitEnum
    def initialize(value)
      expected = self.class.constants.map { |k| self.class.const_get(k) }
      unless expected.include?(value)
        raise ArgumentError, "#{value} is not a valid #{self.class.name}, please select one of #{expected.inspect}"
      end
      @value = value
    end

    def to_s
      @value
    end
  end

  class PaddingStrategy < ExplicitEnum
    LONGEST = "longest"
    MAX_LENGTH = "max_length"
    DO_NOT_PAD = "do_not_pad"
  end

  class TensorType < ExplicitEnum
    PYTORCH = "pt"
    TENSORFLOW = "tf"
    NUMPY = "np"
    JAX = "jax"
    MLX = "mlx"
  end

  module Utils
    def self.infer_framework(model_class)
      if model_class < Torch::NN::Module
        "pt"
      else
        raise TypeError, "Could not infer framework from class #{model_class}."
      end
    end

    def self._is_numo(x)
      x.is_a?(Numo::NArray)
    end

    def self.is_numo_array(x)
      _is_numo(x)
    end

    def self._is_torch(x)
      x.is_a?(Torch::Tensor)
    end

    def self.is_torch_tensor(x)
      _is_torch(x)
    end

    def self._is_torch_device(x)
      x.is_a?(Torch::Device)
    end

    def self.is_torch_device(x)
      _is_torch_device(x)
    end
  end
end
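The `ModelOutput` class above is a small attribute DSL: `attribute` registers a field and generates a reader, `[]` accepts a string, symbol, or integer index, and `to_tuple` drops nil fields. A minimal sketch of how a subclass behaves; the class and field names are made up for illustration (the gem's real output classes live in `modeling_outputs.rb`):

```ruby
class ExampleOutput < Transformers::ModelOutput
  attribute :last_hidden_state
  attribute :attentions
end

out = ExampleOutput.new(last_hidden_state: [[0.1, 0.2]])
out.last_hidden_state        # reader generated by `attribute`
out[:last_hidden_state]      # hash-style access, string or symbol key
out[0]                       # positional access via to_tuple; nil fields (attentions) are skipped

Transformers::PaddingStrategy.new("longest").to_s  # => "longest"
# Transformers::PaddingStrategy.new("bogus")       # raises ArgumentError listing the valid values
```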