transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/tokenization_utils_fast.rb
@@ -0,0 +1,386 @@
+ # Copyright 2020 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class PreTrainedTokenizerFast < PreTrainedTokenizerBase
+     def initialize(*args, **kwargs)
+       tokenizer_object = kwargs.delete(:tokenizer_object)
+       slow_tokenizer = kwargs.delete(:__slow_tokenizer)
+       fast_tokenizer_file = kwargs.delete(:tokenizer_file)
+       from_slow = kwargs.delete(:from_slow) { false }
+       _added_tokens_decoder = kwargs.delete(:added_tokens_decoder)
+
+       if !tokenizer_object.nil?
+         fast_tokenizer = Copy.deepcopy(tokenizer_object)
+       elsif !fast_tokenizer_file.nil? && !from_slow
+         # We have a serialization from tokenizers which lets us directly build the backend
+         fast_tokenizer = Tokenizers::Tokenizer.from_file(fast_tokenizer_file)
+       elsif !slow_tokenizer.nil?
+         # We need to convert a slow tokenizer to build the backend
+         fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
+       elsif !@slow_tokenizer_class.nil?
+         # We need to create and convert a slow tokenizer to build the backend
+         slow_tokenizer = @slow_tokenizer_class.new(*args, **kwargs)
+         fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
+       else
+         raise ArgumentError, <<~MSG
+           Couldn't instantiate the backend tokenizer from one of:
+           (1) a `tokenizers` library serialization file,
+           (2) a slow tokenizer instance to convert or
+           (3) an equivalent slow tokenizer class to instantiate and convert.
+           You need to have sentencepiece installed to convert a slow tokenizer to a fast one.
+         MSG
+       end
+
+       @tokenizer = fast_tokenizer
+
+       if !slow_tokenizer.nil?
+         kwargs.merge!(slow_tokenizer.init_kwargs)
+       end
+
+       @decode_use_source_tokenizer = false
+
+       _truncation = @tokenizer.truncation
+
+       if !_truncation.nil?
+         _truncation = _truncation.transform_keys(&:to_sym)
+         @tokenizer.enable_truncation(_truncation[:max_length], **_truncation.except(:max_length))
+         kwargs[:max_length] ||= _truncation[:max_length]
+         kwargs[:truncation_side] ||= _truncation[:direction]
+         kwargs[:stride] ||= _truncation[:stride]
+         kwargs[:truncation_strategy] ||= _truncation[:strategy]
+       else
+         @tokenizer.no_truncation
+       end
+
+       _padding = @tokenizer.padding
+       if !_padding.nil?
+         _padding = _padding.transform_keys(&:to_sym)
+         @tokenizer.enable_padding(**_padding)
+         kwargs[:pad_token] ||= _padding[:pad_token]
+         kwargs[:pad_token_type_id] ||= _padding[:pad_token_type_id]
+         kwargs[:padding_side] ||= _padding[:direction]
+         kwargs[:max_length] ||= _padding[:length]
+         kwargs[:pad_to_multiple_of] ||= _padding[:pad_to_multiple_of]
+       end
+
+       # We call this after having initialized the backend tokenizer because we update it.
+       super(**kwargs)
+     end
+
+     def is_fast
+       true
+     end
+
+     def get_vocab
+       @tokenizer.vocab(with_added_tokens: true)
+     end
+
+     def vocab
+       get_vocab
+     end
+
+     def convert_tokens_to_ids(tokens)
+       if tokens.nil?
+         return nil
+       end
+
+       if tokens.is_a?(String)
+         return _convert_token_to_id_with_added_voc(tokens)
+       end
+
+       ids = []
+       tokens.each do |token|
+         ids << _convert_token_to_id_with_added_voc(token)
+       end
+       ids
+     end
+
+     def _convert_token_to_id_with_added_voc(token)
+       index = @tokenizer.token_to_id(token)
+       if index.nil?
+         return unk_token_id
+       end
+       index
+     end
+
+     def convert_ids_to_tokens(ids, skip_special_tokens: false)
+       if ids.is_a?(Integer)
+         return @tokenizer.id_to_token(ids)
+       end
+       tokens = []
+       ids.each do |index|
+         index = index.to_i
+         if skip_special_tokens && @all_special_ids.include?(index)
+           next
+         end
+         tokens << @tokenizer.id_to_token(index)
+       end
+       tokens
+     end
+
+     private
+
+     def set_truncation_and_padding(
+       padding_strategy:,
+       truncation_strategy:,
+       max_length:,
+       stride:,
+       pad_to_multiple_of:
+     )
+       _truncation = @tokenizer.truncation
+       _padding = @tokenizer.padding
+       # Set truncation and padding on the backend tokenizer
+       if truncation_strategy == TruncationStrategy::DO_NOT_TRUNCATE
+         if !_truncation.nil?
+           @tokenizer.no_truncation
+         end
+       else
+         target = {
+           max_length: max_length,
+           stride: stride,
+           strategy: truncation_strategy,
+           direction: @truncation_side
+         }
+
+         # _truncation might contain more keys than the target `transformers`
+         # supports. Use only the target keys to trigger `enable_truncation`.
+         # This should enable this code to work on various `tokenizers`
+         # targets.
+         if _truncation.nil?
+           current = nil
+         else
+           current = target.to_h { |k, _| [k, _truncation[k]] }
+         end
+
+         if current != target
+           @tokenizer.enable_truncation(target.delete(:max_length), **target)
+         end
+       end
+
+       if padding_strategy == PaddingStrategy::DO_NOT_PAD
+         if !_padding.nil?
+           @tokenizer.no_padding
+         end
+       else
+         length = padding_strategy == PaddingStrategy::MAX_LENGTH ? max_length : nil
+         target = {
+           length: length,
+           direction: @padding_side,
+           pad_id: @pad_token_id,
+           pad_token: @pad_token,
+           pad_type_id: @pad_token_type_id,
+           pad_to_multiple_of: pad_to_multiple_of
+         }
+         if _padding != target
+           @tokenizer.enable_padding(**target)
+         end
+       end
+     end
+
+     def _batch_encode_plus(
+       batch_text_or_text_pairs,
+       add_special_tokens: true,
+       padding_strategy: PaddingStrategy::DO_NOT_PAD,
+       truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+       max_length: nil,
+       stride: 0,
+       is_split_into_words: false,
+       pad_to_multiple_of: nil,
+       return_tensors: nil,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true
+     )
+       if !batch_text_or_text_pairs.is_a?(Array)
+         raise TypeError, "batch_text_or_text_pairs has to be an array (got #{batch_text_or_text_pairs.class.name})"
+       end
+
+       # Set the truncation and padding strategy and restore the initial configuration
+       set_truncation_and_padding(
+         padding_strategy: padding_strategy,
+         truncation_strategy: truncation_strategy,
+         max_length: max_length,
+         stride: stride,
+         pad_to_multiple_of: pad_to_multiple_of
+       )
+
+       encodings =
+         @tokenizer.encode_batch(
+           batch_text_or_text_pairs,
+           add_special_tokens: add_special_tokens,
+           is_pretokenized: is_split_into_words,
+         )
+
+       # Convert encoding to dict
+       # `Tokens` has type: Tuple[
+       #   List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
+       #   List[EncodingFast]
+       # ]
+       # with nested dimensions corresponding to batch, overflows, sequence length
+       tokens_and_encodings =
+         encodings.map do |encoding|
+           _convert_encoding(
+             encoding: encoding,
+             return_token_type_ids: return_token_type_ids,
+             return_attention_mask: return_attention_mask,
+             return_overflowing_tokens: return_overflowing_tokens,
+             return_special_tokens_mask: return_special_tokens_mask,
+             return_offsets_mapping: return_offsets_mapping,
+             return_length: return_length,
+             verbose: verbose
+           )
+         end
+
+       # Convert the output from a list of hashes to a hash of lists and remove the additional overflows dimension
+       # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+       # (we say ~ because the number of overflows varies with each example in the batch)
+       #
+       # To match each overflowing sample with the original sample in the batch
+       # we add an overflow_to_sample_mapping array (see below)
+       sanitized_tokens = {}
+       tokens_and_encodings[0][0].each_key do |key|
+         stack = tokens_and_encodings.map { |item, _| item[key][0] }
+         sanitized_tokens[key] = stack
+       end
+       sanitized_encodings = tokens_and_encodings.map { |_, item| item[0] }
+
+       # If returning overflowing tokens, we need to return a mapping
+       # from the batch idx to the original sample
+       if return_overflowing_tokens
+         overflow_to_sample_mapping = []
+         tokens_and_encodings.each_with_index do |(toks, _), i|
+           overflow_to_sample_mapping += [i] * toks["input_ids"].length
+         end
+         sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+       end
+
+       sanitized_tokens["input_ids"].each do |input_ids|
+         _eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+       end
+
+       BatchEncoding.new(data: sanitized_tokens, encoding: sanitized_encodings, tensor_type: return_tensors)
+     end
+
+     def _convert_encoding(
+       encoding:,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true
+     )
+       if return_token_type_ids.nil?
+         return_token_type_ids = self.class.model_input_names.include?("token_type_ids")
+       end
+       if return_attention_mask.nil?
+         return_attention_mask = self.class.model_input_names.include?("attention_mask")
+       end
+
+       if return_overflowing_tokens && !encoding.overflowing.nil?
+         encodings = [encoding] + encoding.overflowing
+       else
+         encodings = [encoding]
+       end
+
+       encoding_dict = Hash.new { |h, k| h[k] = [] }
+       encodings.each do |e|
+         encoding_dict["input_ids"] << e.ids
+
+         if return_token_type_ids
+           encoding_dict["token_type_ids"] << e.type_ids
+         end
+         if return_attention_mask
+           encoding_dict["attention_mask"] << e.attention_mask
+         end
+         if return_special_tokens_mask
+           encoding_dict["special_tokens_mask"] << e.special_tokens_mask
+         end
+         if return_offsets_mapping
+           encoding_dict["offset_mapping"] << e.offsets
+         end
+         if return_length
+           encoding_dict["length"] << e.ids.length
+         end
+       end
+
+       [encoding_dict, encodings]
+     end
+
+     def _encode_plus(
+       text:,
+       text_pair: nil,
+       add_special_tokens: true,
+       padding_strategy: PaddingStrategy::DO_NOT_PAD,
+       truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+       max_length: nil,
+       stride: 0,
+       is_split_into_words: false,
+       pad_to_multiple_of: nil,
+       return_tensors: nil,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true,
+       **kwargs
+     )
+       batched_input = text_pair ? [[text, text_pair]] : [text]
+       batched_output =
+         _batch_encode_plus(
+           batched_input,
+           is_split_into_words: is_split_into_words,
+           add_special_tokens: add_special_tokens,
+           padding_strategy: padding_strategy,
+           truncation_strategy: truncation_strategy,
+           max_length: max_length,
+           stride: stride,
+           pad_to_multiple_of: pad_to_multiple_of,
+           return_tensors: return_tensors,
+           return_token_type_ids: return_token_type_ids,
+           return_attention_mask: return_attention_mask,
+           return_overflowing_tokens: return_overflowing_tokens,
+           return_special_tokens_mask: return_special_tokens_mask,
+           return_offsets_mapping: return_offsets_mapping,
+           return_length: return_length,
+           verbose: verbose,
+           **kwargs
+         )
+
+       # If return_tensors is nil, we can remove the leading batch axis.
+       # Overflowing tokens are returned as a batch of outputs, so we keep them in that case.
+       if return_tensors.nil? && !return_overflowing_tokens
+         batched_output =
+           BatchEncoding.new(
+             data: batched_output.items.to_h { |key, value|
+               [key, value.length > 0 && value[0].is_a?(Array) ? value[0] : value]
+             },
+             encoding: batched_output.encodings
+           )
+       end
+
+       _eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+       batched_output
+     end
+   end
+ end
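
For context, a minimal usage sketch of the fast tokenizer path above. It assumes `Transformers::AutoTokenizer.from_pretrained` (see tokenization_auto.rb in the file list) and a `call` method on the base class mirroring Python's `__call__`; the model name is illustrative, not part of this diff.

require "transformers-rb"

# Resolves to a PreTrainedTokenizerFast subclass backed by the tokenizers gem.
tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")

enc = tokenizer.("Ruby loves transformers")
enc["input_ids"]                                   # token ids (leading batch axis removed by _encode_plus)
tokenizer.convert_ids_to_tokens(enc["input_ids"])  # back to tokens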
data/lib/transformers/torch_utils.rb
@@ -0,0 +1,25 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module TorchUtils
+     def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
+       if chunk_size > 0
+         raise Todo
+       end
+
+       forward_fn.(*input_tensors)
+     end
+   end
+ end
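
A quick sketch of the behavior above: with a chunk_size of 0 the forward lambda is applied directly to the inputs, while any positive chunk_size raises Todo because chunked forward has not been ported yet. The tensor value is illustrative.

double = ->(x) { x * 2 }
Transformers::TorchUtils.apply_chunking_to_forward(double, 0, 1, Torch.tensor([1, 2, 3]))
# => Torch tensor [2, 4, 6]
Transformers::TorchUtils.apply_chunking_to_forward(double, 2, 1, Torch.tensor([1, 2, 3]))
# => raises Todo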
data/lib/transformers/utils/_init.rb
@@ -0,0 +1,31 @@
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   WEIGHTS_NAME = "pytorch_model.bin"
+   WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+   TF2_WEIGHTS_NAME = "tf_model.h5"
+   TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
+   TF_WEIGHTS_NAME = "model.ckpt"
+   FLAX_WEIGHTS_NAME = "flax_model.msgpack"
+   FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
+   SAFE_WEIGHTS_NAME = "model.safetensors"
+   SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+   CONFIG_NAME = "config.json"
+   FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
+   IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
+   PROCESSOR_NAME = "processor_config.json"
+   GENERATION_CONFIG_NAME = "generation_config.json"
+   MODEL_CARD_NAME = "modelcard.json"
+ end
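
These constants name the files that loading code probes for in a model repo or local directory. A hypothetical helper (not part of this diff) showing one way they might be used, assuming safetensors is preferred over the PyTorch pickle format as in the Python library:

def find_weights(model_dir)
  [Transformers::SAFE_WEIGHTS_NAME, Transformers::WEIGHTS_NAME]
    .map { |name| File.join(model_dir, name) }
    .find { |path| File.exist?(path) }
end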
data/lib/transformers/utils/generic.rb
@@ -0,0 +1,107 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class ModelOutput
+     def self.attributes
+       @attributes ||= []
+     end
+
+     def self.attribute(attribute)
+       attributes << attribute.to_sym
+
+       define_method(attribute) do
+         self[attribute]
+       end
+     end
+
+     def initialize(**kwargs)
+       @data = kwargs
+     end
+
+     def [](k)
+       if k.is_a?(String) || k.is_a?(Symbol)
+         @data[k.to_sym]
+       else
+         to_tuple[k]
+       end
+     end
+
+     def to_tuple
+       self.class.attributes.map { |k| @data[k] }.compact
+     end
+   end
+
+   class ExplicitEnum
+     def initialize(value)
+       expected = self.class.constants.map { |k| self.class.const_get(k) }
+       unless expected.include?(value)
+         raise ArgumentError, "#{value} is not a valid #{self.class.name}, please select one of #{expected.inspect}"
+       end
+       @value = value
+     end
+
+     def to_s
+       @value
+     end
+   end
+
+   class PaddingStrategy < ExplicitEnum
+     LONGEST = "longest"
+     MAX_LENGTH = "max_length"
+     DO_NOT_PAD = "do_not_pad"
+   end
+
+   class TensorType < ExplicitEnum
+     PYTORCH = "pt"
+     TENSORFLOW = "tf"
+     NUMPY = "np"
+     JAX = "jax"
+     MLX = "mlx"
+   end
+
+   module Utils
+     def self.infer_framework(model_class)
+       if model_class < Torch::NN::Module
+         "pt"
+       else
+         raise TypeError, "Could not infer framework from class #{model_class}."
+       end
+     end
+
+     def self._is_numo(x)
+       x.is_a?(Numo::NArray)
+     end
+
+     def self.is_numo_array(x)
+       _is_numo(x)
+     end
+
+     def self._is_torch(x)
+       x.is_a?(Torch::Tensor)
+     end
+
+     def self.is_torch_tensor(x)
+       _is_torch(x)
+     end
+
+     def self._is_torch_device(x)
+       x.is_a?(Torch::Device)
+     end
+
+     def self.is_torch_device(x)
+       _is_torch_device(x)
+     end
+   end
+ end
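
To illustrate the ModelOutput contract above: declared attributes are readable by name, by string or symbol key, or positionally via to_tuple (which drops nil entries). ExampleOutput is a hypothetical subclass, not part of this diff.

class ExampleOutput < Transformers::ModelOutput
  attribute :logits
  attribute :hidden_states
end

out = ExampleOutput.new(logits: [0.1, 0.9])
out.logits    # => [0.1, 0.9]
out["logits"] # => [0.1, 0.9] (string and symbol keys are equivalent)
out[0]        # => [0.1, 0.9] (hidden_states is nil, so to_tuple compacts it away)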