transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,386 @@
|
|
1
|
+
# Copyright 2020 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
class PreTrainedTokenizerFast < PreTrainedTokenizerBase
|
17
|
+
# Builds the backend tokenizer, re-applies the truncation/padding settings it
# already carries, and then calls the parent initializer with the remaining
# keyword options.
#
# The backend is taken from the first available source, in this order:
#   1. `tokenizer_object:`  - an existing backend tokenizer (deep-copied),
#   2. `tokenizer_file:`    - a serialized tokenizer file (skipped when `from_slow:` is truthy),
#   3. `__slow_tokenizer:`  - a slow tokenizer instance that gets converted,
#   4. `@slow_tokenizer_class` - instantiated with `args`/`kwargs`, then converted.
#
# @param args [Array] forwarded to the slow tokenizer class (case 4 only)
# @param kwargs [Hash] tokenizer options; the keys named above are consumed here,
#   the rest is passed on to the parent initializer
# @raise [ArgumentError] when none of the four sources is available
def initialize(*args, **kwargs)
  tokenizer_object = kwargs.delete(:tokenizer_object)
  slow_tokenizer = kwargs.delete(:__slow_tokenizer)
  fast_tokenizer_file = kwargs.delete(:tokenizer_file)
  from_slow = kwargs.delete(:from_slow) { false }
  # Removed so it is not forwarded to the parent initializer; otherwise unused here.
  _added_tokens_decoder = kwargs.delete(:added_tokens_decoder)

  if !tokenizer_object.nil?
    fast_tokenizer = Copy.deepcopy(tokenizer_object)
  elsif !fast_tokenizer_file.nil? && !from_slow
    # We have a serialization from tokenizers which lets us directly build the backend
    fast_tokenizer = Tokenizers::Tokenizer.from_file(fast_tokenizer_file)
  elsif !slow_tokenizer.nil?
    # We need to convert a slow tokenizer to build the backend
    fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
  elsif !@slow_tokenizer_class.nil?
    # We need to create and convert a slow tokenizer to build the backend
    slow_tokenizer = @slow_tokenizer_class.new(*args, **kwargs)
    fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
  else
    raise ArgumentError, <<~MSG
      Couldn't instantiate the backend tokenizer from one of:
      (1) a `tokenizers` library serialization file,
      (2) a slow tokenizer instance to convert or
      (3) an equivalent slow tokenizer class to instantiate and convert.
      You need to have sentencepiece installed to convert a slow tokenizer to a fast one.
    MSG
  end

  @tokenizer = fast_tokenizer

  # `merge!` overwrites: the slow tokenizer's init kwargs win over the ones passed in.
  if !slow_tokenizer.nil?
    kwargs.merge!(slow_tokenizer.init_kwargs)
  end

  @decode_use_source_tokenizer = false

  # Backend settings only fill in options the caller left unset (`||=`).
  _truncation = @tokenizer.truncation

  if !_truncation.nil?
    _truncation = _truncation.transform_keys(&:to_sym)
    @tokenizer.enable_truncation(_truncation[:max_length], **_truncation.except(:max_length))
    kwargs[:max_length] ||= _truncation[:max_length]
    kwargs[:truncation_side] ||= _truncation[:direction]
    kwargs[:stride] ||= _truncation[:stride]
    kwargs[:truncation_strategy] ||= _truncation[:strategy]
  else
    @tokenizer.no_truncation
  end

  _padding = @tokenizer.padding
  if !_padding.nil?
    _padding = _padding.transform_keys(&:to_sym)
    @tokenizer.enable_padding(**_padding)
    kwargs[:pad_token] ||= _padding[:pad_token]
    # The backend names this setting `pad_type_id` (the keyword that
    # `enable_padding` receives), so read it under that key; looking up
    # `:pad_token_type_id` always produced nil.
    kwargs[:pad_token_type_id] ||= _padding[:pad_type_id]
    kwargs[:padding_side] ||= _padding[:direction]
    kwargs[:max_length] ||= _padding[:length]
    kwargs[:pad_to_multiple_of] ||= _padding[:pad_to_multiple_of]
  end

  # We call this after having initialized the backend tokenizer because we update it.
  super(**kwargs)
end
|
81
|
+
|
82
|
+
# Reports that this is a fast tokenizer.
#
# @return [Boolean] always true
def is_fast = true
|
85
|
+
|
86
|
+
# Returns the backend tokenizer's vocabulary, added tokens included.
#
# @return [Object] whatever `@tokenizer.vocab(with_added_tokens: true)` returns
def get_vocab = @tokenizer.vocab(with_added_tokens: true)
|
89
|
+
|
90
|
+
# Shorthand for #get_vocab.
#
# @return [Object] the value returned by #get_vocab
def vocab = get_vocab
|
93
|
+
|
94
|
+
# Maps a token, or a collection of tokens, to vocabulary ids.
#
# @param tokens [String, #each, nil] a single token, several tokens, or nil
# @return [Object, Array, nil] one id for a String, an array of ids for a
#   collection, nil when +tokens+ is nil
def convert_tokens_to_ids(tokens)
  case tokens
  when nil
    nil
  when String
    _convert_token_to_id_with_added_voc(tokens)
  else
    tokens.map { |token| _convert_token_to_id_with_added_voc(token) }
  end
end
|
109
|
+
|
110
|
+
# Looks up one token in the backend tokenizer, falling back to `unk_token_id`
# when the backend returns nil for it.
#
# @param token [String] token to look up
# @return [Object] the backend id, or `unk_token_id` when the lookup is nil
def _convert_token_to_id_with_added_voc(token)
  id = @tokenizer.token_to_id(token)
  id.nil? ? unk_token_id : id
end
|
117
|
+
|
118
|
+
# Maps an id, or a collection of ids, back to tokens via the backend tokenizer.
#
# @param ids [Integer, #each] a single id or several ids (each is converted with `to_i`)
# @param skip_special_tokens [Boolean] when true, ids contained in
#   `@all_special_ids` are left out of the result
# @return [Object, Array] one token for an Integer, otherwise an array of tokens
def convert_ids_to_tokens(ids, skip_special_tokens: false)
  return @tokenizer.id_to_token(ids) if ids.is_a?(Integer)

  tokens = []
  ids.each do |raw_id|
    id = raw_id.to_i
    tokens << @tokenizer.id_to_token(id) unless skip_special_tokens && @all_special_ids.include?(id)
  end
  tokens
end
|
132
|
+
|
133
|
+
private
|
134
|
+
|
135
|
+
# Configures truncation and padding on the backend tokenizer, touching the
# backend only when the requested settings differ from the ones it reports.
#
# @param padding_strategy [Object] compared against PaddingStrategy::DO_NOT_PAD / MAX_LENGTH
# @param truncation_strategy [Object] compared against TruncationStrategy::DO_NOT_TRUNCATE
# @param max_length [Integer, nil] truncation length; also the padding length for MAX_LENGTH
# @param stride [Integer] passed to the backend as the truncation stride
# @param pad_to_multiple_of [Integer, nil] passed to the backend's padding settings
# @return [void]
def set_truncation_and_padding(
  padding_strategy:,
  truncation_strategy:,
  max_length:,
  stride:,
  pad_to_multiple_of:
)
  # Normalize the backend's hashes to symbol keys, as the initializer does, so
  # they can be compared with the symbol-keyed targets built below. Without
  # this the comparison never matches and the backend is reconfigured on every call.
  current_truncation = @tokenizer.truncation&.transform_keys(&:to_sym)
  current_padding = @tokenizer.padding&.transform_keys(&:to_sym)

  # Set truncation and padding on the backend tokenizer
  if truncation_strategy == TruncationStrategy::DO_NOT_TRUNCATE
    @tokenizer.no_truncation unless current_truncation.nil?
  else
    target = {
      max_length: max_length,
      stride: stride,
      strategy: truncation_strategy,
      direction: @truncation_side
    }

    # The backend's truncation settings might contain more keys than this code
    # sets. Compare only the target keys to decide whether `enable_truncation`
    # is needed, so this keeps working across `tokenizers` versions.
    current =
      if current_truncation.nil?
        nil
      else
        target.to_h { |key, _| [key, current_truncation[key]] }
      end

    if current != target
      # `except` leaves `target` untouched instead of deleting from it mid-call.
      @tokenizer.enable_truncation(target[:max_length], **target.except(:max_length))
    end
  end

  if padding_strategy == PaddingStrategy::DO_NOT_PAD
    @tokenizer.no_padding unless current_padding.nil?
  else
    # A fixed length is only requested for MAX_LENGTH; otherwise the length is nil.
    length = padding_strategy == PaddingStrategy::MAX_LENGTH ? max_length : nil
    target = {
      length: length,
      direction: @padding_side,
      pad_id: @pad_token_id,
      pad_token: @pad_token,
      pad_type_id: @pad_token_type_id,
      pad_to_multiple_of: pad_to_multiple_of
    }
    if current_padding != target
      @tokenizer.enable_padding(**target)
    end
  end
end
|
191
|
+
|
192
|
+
# Encodes a batch with the backend tokenizer and wraps the result in a
# BatchEncoding.
#
# Every sample is converted with #_convert_encoding, which yields one entry per
# encoding (the main one plus any overflow encodings). Those per-sample lists
# are concatenated, so each value in the result has one row per encoding.
#
# @param batch_text_or_text_pairs [Array] inputs handed to the backend's `encode_batch`
# @param add_special_tokens [Boolean] forwarded to `encode_batch`
# @param padding_strategy [Object] forwarded to #set_truncation_and_padding
# @param truncation_strategy [Object] forwarded to #set_truncation_and_padding
# @param max_length [Integer, nil] forwarded to #set_truncation_and_padding and the length warning
# @param stride [Integer] forwarded to #set_truncation_and_padding
# @param is_split_into_words [Boolean] forwarded to `encode_batch` as `is_pretokenized`
# @param pad_to_multiple_of [Integer, nil] forwarded to #set_truncation_and_padding
# @param return_tensors [Object, nil] forwarded to BatchEncoding as `tensor_type`
# @param return_overflowing_tokens [Boolean] when true, adds the
#   "overflow_to_sample_mapping" entry (index of the originating sample per row)
# @param verbose [Boolean] forwarded to the length warning
# @return [BatchEncoding]
# @raise [TypeError] when +batch_text_or_text_pairs+ is not an Array
#
# The remaining `return_*` options are forwarded to #_convert_encoding.
# NOTE(review): an empty batch is not handled here - if the backend returns no
# encodings, `tokens_and_encodings[0]` is nil.
def _batch_encode_plus(
  batch_text_or_text_pairs,
  add_special_tokens: true,
  padding_strategy: PaddingStrategy::DO_NOT_PAD,
  truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
  max_length: nil,
  stride: 0,
  is_split_into_words: false,
  pad_to_multiple_of: nil,
  return_tensors: nil,
  return_token_type_ids: nil,
  return_attention_mask: nil,
  return_overflowing_tokens: false,
  return_special_tokens_mask: false,
  return_offsets_mapping: false,
  return_length: false,
  verbose: true
)
  unless batch_text_or_text_pairs.is_a?(Array)
    raise TypeError, "batch_text_or_text_pairs has to be an array (got #{batch_text_or_text_pairs.class.name})"
  end

  # Configure the backend's truncation and padding for this call
  set_truncation_and_padding(
    padding_strategy: padding_strategy,
    truncation_strategy: truncation_strategy,
    max_length: max_length,
    stride: stride,
    pad_to_multiple_of: pad_to_multiple_of
  )

  encodings =
    @tokenizer.encode_batch(
      batch_text_or_text_pairs,
      add_special_tokens: add_special_tokens,
      is_pretokenized: is_split_into_words
    )

  # One [hash, encodings] pair per sample; the hash values and the encodings
  # list have one entry per encoding (main + overflows) of that sample.
  tokens_and_encodings =
    encodings.map do |encoding|
      _convert_encoding(
        encoding: encoding,
        return_token_type_ids: return_token_type_ids,
        return_attention_mask: return_attention_mask,
        return_overflowing_tokens: return_overflowing_tokens,
        return_special_tokens_mask: return_special_tokens_mask,
        return_offsets_mapping: return_offsets_mapping,
        return_length: return_length,
        verbose: verbose
      )
    end

  # Convert the output from list[hash] to hash[list] and remove the additional overflows dimension:
  # from (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
  # (we say ~ because the number of overflows varies with the example in the batch).
  #
  # Every encoding of every sample is kept (flat_map). Taking only the first
  # entry per sample would drop the overflow encodings while the mapping below
  # still counts them.
  sanitized_tokens = {}
  tokens_and_encodings[0][0].each_key do |key|
    sanitized_tokens[key] = tokens_and_encodings.flat_map { |item, _| item[key] }
  end
  sanitized_encodings = tokens_and_encodings.flat_map { |_, item| item }

  # If returning overflowing tokens, we need to return a mapping
  # from each output row to the index of the original sample
  if return_overflowing_tokens
    sanitized_tokens["overflow_to_sample_mapping"] =
      tokens_and_encodings.each_with_index.flat_map do |(toks, _), i|
        [i] * toks["input_ids"].length
      end
  end

  sanitized_tokens["input_ids"].each do |input_ids|
    _eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
  end

  BatchEncoding.new(data: sanitized_tokens, encoding: sanitized_encodings, tensor_type: return_tensors)
end
|
279
|
+
|
280
|
+
# Turns one backend encoding (plus, optionally, its overflow encodings) into a
# hash of lists and the list of encodings it was built from.
#
# @param encoding [Object] backend encoding; read via `ids`, `type_ids`,
#   `attention_mask`, `special_tokens_mask`, `offsets` and `overflowing`
# @param return_token_type_ids [Boolean, nil] nil means: include when
#   `self.class.model_input_names` lists "token_type_ids"
# @param return_attention_mask [Boolean, nil] nil means: include when
#   `self.class.model_input_names` lists "attention_mask"
# @param return_overflowing_tokens [Boolean] also convert `encoding.overflowing` (when not nil)
# @param return_special_tokens_mask [Boolean] include "special_tokens_mask"
# @param return_offsets_mapping [Boolean] include "offset_mapping"
# @param return_length [Boolean] include "length" (number of ids per encoding)
# @param verbose [Boolean] accepted but not used in this method
# @return [Array(Hash, Array)] the hash (one list entry per encoding) and the encodings
def _convert_encoding(
  encoding:,
  return_token_type_ids: nil,
  return_attention_mask: nil,
  return_overflowing_tokens: false,
  return_special_tokens_mask: false,
  return_offsets_mapping: false,
  return_length: false,
  verbose: true
)
  return_token_type_ids = self.class.model_input_names.include?("token_type_ids") if return_token_type_ids.nil?
  return_attention_mask = self.class.model_input_names.include?("attention_mask") if return_attention_mask.nil?

  overflow = return_overflowing_tokens ? encoding.overflowing : nil
  encodings = overflow.nil? ? [encoding] : [encoding] + overflow

  # Output key, whether it was requested, and how to read it from an encoding.
  # The order here is the order in which keys appear in the result.
  fields = [
    ["input_ids", true, ->(e) { e.ids }],
    ["token_type_ids", return_token_type_ids, ->(e) { e.type_ids }],
    ["attention_mask", return_attention_mask, ->(e) { e.attention_mask }],
    ["special_tokens_mask", return_special_tokens_mask, ->(e) { e.special_tokens_mask }],
    ["offset_mapping", return_offsets_mapping, ->(e) { e.offsets }],
    ["length", return_length, ->(e) { e.ids.length }]
  ].select { |_, wanted, _| wanted }

  encoding_dict = Hash.new { |hash, key| hash[key] = [] }
  encodings.each do |e|
    fields.each { |key, _, reader| encoding_dict[key] << reader.call(e) }
  end

  [encoding_dict, encodings]
end
|
326
|
+
|
327
|
+
# Encodes a single text (or text pair) by running it through
# #_batch_encode_plus as a batch of one.
#
# When no tensor type is requested and overflowing tokens are not returned, the
# leading batch axis is removed again: every value whose first element is an
# Array is replaced by that first element.
#
# @param text [Object] the text to encode
# @param text_pair [Object, nil] optional second text; when truthy the batch
#   entry is `[text, text_pair]`
# @param max_length [Integer, nil] also used for the final length warning
# @param return_tensors [Object, nil] see above
# @param return_overflowing_tokens [Boolean] see above
# @param verbose [Boolean] forwarded to the length warning
# @param kwargs [Hash] forwarded, like all remaining options, to #_batch_encode_plus
# @return [BatchEncoding]
def _encode_plus(
  text:,
  text_pair: nil,
  add_special_tokens: true,
  padding_strategy: PaddingStrategy::DO_NOT_PAD,
  truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
  max_length: nil,
  stride: 0,
  is_split_into_words: false,
  pad_to_multiple_of: nil,
  return_tensors: nil,
  return_token_type_ids: nil,
  return_attention_mask: nil,
  return_overflowing_tokens: false,
  return_special_tokens_mask: false,
  return_offsets_mapping: false,
  return_length: false,
  verbose: true,
  **kwargs
)
  single_sample = text_pair ? [text, text_pair] : text

  batched_output =
    _batch_encode_plus(
      [single_sample],
      is_split_into_words: is_split_into_words,
      add_special_tokens: add_special_tokens,
      padding_strategy: padding_strategy,
      truncation_strategy: truncation_strategy,
      max_length: max_length,
      stride: stride,
      pad_to_multiple_of: pad_to_multiple_of,
      return_tensors: return_tensors,
      return_token_type_ids: return_token_type_ids,
      return_attention_mask: return_attention_mask,
      return_overflowing_tokens: return_overflowing_tokens,
      return_special_tokens_mask: return_special_tokens_mask,
      return_offsets_mapping: return_offsets_mapping,
      return_length: return_length,
      verbose: verbose,
      **kwargs
    )

  # Without a tensor type we can drop the leading batch axis. Overflowing
  # tokens come back as a batch of outputs, so the axis is kept in that case.
  if return_tensors.nil? && !return_overflowing_tokens
    unbatched =
      batched_output.items.to_h do |key, value|
        nested = value.length > 0 && value[0].is_a?(Array)
        [key, nested ? value[0] : value]
      end
    batched_output = BatchEncoding.new(data: unbatched, encoding: batched_output.encodings)
  end

  _eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)

  batched_output
end
|
385
|
+
end
|
386
|
+
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
module TorchUtils
  # Calls +forward_fn+ with the given tensors.
  #
  # Chunked execution is not implemented: a +chunk_size+ greater than zero
  # raises `Todo`. +chunk_dim+ is accepted but not used.
  #
  # @param forward_fn [#call] callable invoked with +input_tensors+
  # @param chunk_size [Integer] must not be positive
  # @param chunk_dim [Object] unused
  # @param input_tensors [Array] arguments passed on to +forward_fn+
  # @return [Object] whatever +forward_fn+ returns
  def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
    raise Todo if chunk_size > 0

    forward_fn.call(*input_tensors)
  end
end
|
25
|
+
end
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
  # File-name constants. The strings are frozen so that code sharing these
  # constants cannot modify them in place.
  # NOTE(review): presumably the names looked up when loading model files -
  # confirm against the callers.

  # Model weights, by serialization format.
  WEIGHTS_NAME = "pytorch_model.bin".freeze
  WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json".freeze
  TF2_WEIGHTS_NAME = "tf_model.h5".freeze
  TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json".freeze
  TF_WEIGHTS_NAME = "model.ckpt".freeze
  FLAX_WEIGHTS_NAME = "flax_model.msgpack".freeze
  FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json".freeze
  SAFE_WEIGHTS_NAME = "model.safetensors".freeze
  SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json".freeze

  # Configuration files.
  CONFIG_NAME = "config.json".freeze
  FEATURE_EXTRACTOR_NAME = "preprocessor_config.json".freeze
  # Same object as FEATURE_EXTRACTOR_NAME.
  IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
  PROCESSOR_NAME = "processor_config.json".freeze
  GENERATION_CONFIG_NAME = "generation_config.json".freeze
  MODEL_CARD_NAME = "modelcard.json".freeze
end
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
# Keyword-initialized container whose declared attributes can be read by
# method, by name (`out[:logits]` / `out["logits"]`) or by position (`out[0]`).
class ModelOutput
  class << self
    # Attribute names declared on this class, in declaration order.
    # Stored per class, so each subclass starts with its own empty list.
    #
    # @return [Array<Symbol>]
    def attributes
      @attributes ||= []
    end

    # Declares an attribute and defines a reader method for it.
    #
    # @param attribute [Symbol, String] attribute name
    def attribute(attribute)
      attributes.push(attribute.to_sym)

      define_method(attribute) { self[attribute] }
    end
  end

  # @param kwargs [Hash] attribute values keyed by name
  def initialize(**kwargs)
    @data = kwargs
  end

  # Looks a value up by name (String or Symbol) or, for any other key, by its
  # index in #to_tuple.
  #
  # @param k [String, Symbol, Object] attribute name or tuple index
  # @return [Object, nil]
  def [](k)
    case k
    when String, Symbol
      @data[k.to_sym]
    else
      to_tuple[k]
    end
  end

  # Values of the declared attributes in declaration order, nil values removed.
  #
  # @return [Array]
  def to_tuple
    @data.values_at(*self.class.attributes).compact
  end
end
|
45
|
+
|
46
|
+
# Base class for enumerations whose allowed values are the constants defined
# on the (sub)class.
class ExplicitEnum
  # @param value [Object] must equal the value of one of the class's constants
  # @raise [ArgumentError] when +value+ matches none of them
  def initialize(value)
    enum_class = self.class
    expected = enum_class.constants.map { |const_name| enum_class.const_get(const_name) }
    if !expected.include?(value)
      raise ArgumentError, "#{value} is not a valid #{enum_class.name}, please select one of #{expected.inspect}"
    end

    @value = value
  end

  # @return [Object] the wrapped value, unchanged
  def to_s = @value
end
|
59
|
+
|
60
|
+
# Allowed values for the padding strategy. The strings are frozen so the shared
# constants cannot be modified in place.
class PaddingStrategy < ExplicitEnum
  LONGEST = "longest".freeze
  MAX_LENGTH = "max_length".freeze
  DO_NOT_PAD = "do_not_pad".freeze
end
|
65
|
+
|
66
|
+
# Allowed tensor type identifiers. The strings are frozen so the shared
# constants cannot be modified in place.
class TensorType < ExplicitEnum
  PYTORCH = "pt".freeze
  TENSORFLOW = "tf".freeze
  NUMPY = "np".freeze
  JAX = "jax".freeze
  MLX = "mlx".freeze
end
|
73
|
+
|
74
|
+
# Type-check helpers for Torch and Numo objects.
module Utils
  class << self
    # @param model_class [Class] class to inspect
    # @return [String] "pt" when +model_class+ is a subclass of Torch::NN::Module
    # @raise [TypeError] otherwise
    def infer_framework(model_class)
      return "pt" if model_class < Torch::NN::Module

      raise TypeError, "Could not infer framework from class #{model_class}."
    end

    # @return [Boolean] whether +x+ is a Numo::NArray
    def _is_numo(x) = x.is_a?(Numo::NArray)

    # Public counterpart of ._is_numo.
    def is_numo_array(x) = _is_numo(x)

    # @return [Boolean] whether +x+ is a Torch::Tensor
    def _is_torch(x) = x.is_a?(Torch::Tensor)

    # Public counterpart of ._is_torch.
    def is_torch_tensor(x) = _is_torch(x)

    # @return [Boolean] whether +x+ is a Torch::Device
    def _is_torch_device(x) = x.is_a?(Torch::Device)

    # Public counterpart of ._is_torch_device.
    def is_torch_device(x) = _is_torch_device(x)
  end
end
|
107
|
+
end
|