transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/tokenization_utils_fast.rb
@@ -0,0 +1,386 @@
+ # Copyright 2020 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class PreTrainedTokenizerFast < PreTrainedTokenizerBase
+     def initialize(*args, **kwargs)
+       tokenizer_object = kwargs.delete(:tokenizer_object)
+       slow_tokenizer = kwargs.delete(:__slow_tokenizer)
+       fast_tokenizer_file = kwargs.delete(:tokenizer_file)
+       from_slow = kwargs.delete(:from_slow) { false }
+       _added_tokens_decoder = kwargs.delete(:added_tokens_decoder)
+
+       if !tokenizer_object.nil?
+         fast_tokenizer = Copy.deepcopy(tokenizer_object)
+       elsif !fast_tokenizer_file.nil? && !from_slow
+         # We have a serialization from tokenizers which lets us directly build the backend
+         fast_tokenizer = Tokenizers::Tokenizer.from_file(fast_tokenizer_file)
+       elsif !slow_tokenizer.nil?
+         # We need to convert a slow tokenizer to build the backend
+         fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
+       elsif !@slow_tokenizer_class.nil?
+         # We need to create and convert a slow tokenizer to build the backend
+         slow_tokenizer = @slow_tokenizer_class.new(*args, **kwargs)
+         fast_tokenizer = ConvertSlowTokenizer.convert_slow_tokenizer(slow_tokenizer)
+       else
+         raise ArgumentError, <<~MSG
+           Couldn't instantiate the backend tokenizer from one of:
+           (1) a `tokenizers` library serialization file,
+           (2) a slow tokenizer instance to convert or
+           (3) an equivalent slow tokenizer class to instantiate and convert.
+           You need to have sentencepiece installed to convert a slow tokenizer to a fast one.
+         MSG
+       end
+
+       @tokenizer = fast_tokenizer
+
+       if !slow_tokenizer.nil?
+         kwargs.merge!(slow_tokenizer.init_kwargs)
+       end
+
+       @decode_use_source_tokenizer = false
+
+       _truncation = @tokenizer.truncation
+
+       if !_truncation.nil?
+         _truncation = _truncation.transform_keys(&:to_sym)
+         @tokenizer.enable_truncation(_truncation[:max_length], **_truncation.except(:max_length))
+         kwargs[:max_length] ||= _truncation[:max_length]
+         kwargs[:truncation_side] ||= _truncation[:direction]
+         kwargs[:stride] ||= _truncation[:stride]
+         kwargs[:truncation_strategy] ||= _truncation[:strategy]
+       else
+         @tokenizer.no_truncation
+       end
+
+       _padding = @tokenizer.padding
+       if !_padding.nil?
+         _padding = _padding.transform_keys(&:to_sym)
+         @tokenizer.enable_padding(**_padding)
+         kwargs[:pad_token] ||= _padding[:pad_token]
+         kwargs[:pad_token_type_id] ||= _padding[:pad_token_type_id]
+         kwargs[:padding_side] ||= _padding[:direction]
+         kwargs[:max_length] ||= _padding[:length]
+         kwargs[:pad_to_multiple_of] ||= _padding[:pad_to_multiple_of]
+       end
+
+       # We call this after having initialized the backend tokenizer because we update it.
+       super(**kwargs)
+     end
+
+     def is_fast
+       true
+     end
+
+     def get_vocab
+       @tokenizer.vocab(with_added_tokens: true)
+     end
+
+     def vocab
+       get_vocab
+     end
+
+     def convert_tokens_to_ids(tokens)
+       if tokens.nil?
+         return nil
+       end
+
+       if tokens.is_a?(String)
+         return _convert_token_to_id_with_added_voc(tokens)
+       end
+
+       ids = []
+       tokens.each do |token|
+         ids << _convert_token_to_id_with_added_voc(token)
+       end
+       ids
+     end
+
+     def _convert_token_to_id_with_added_voc(token)
+       index = @tokenizer.token_to_id(token)
+       if index.nil?
+         return unk_token_id
+       end
+       index
+     end
+
+     def convert_ids_to_tokens(ids, skip_special_tokens: false)
+       if ids.is_a?(Integer)
+         return @tokenizer.id_to_token(ids)
+       end
+       tokens = []
+       ids.each do |index|
+         index = index.to_i
+         if skip_special_tokens && @all_special_ids.include?(index)
+           next
+         end
+         tokens << @tokenizer.id_to_token(index)
+       end
+       tokens
+     end
+
+     private
+
+     def set_truncation_and_padding(
+       padding_strategy:,
+       truncation_strategy:,
+       max_length:,
+       stride:,
+       pad_to_multiple_of:
+     )
+       _truncation = @tokenizer.truncation
+       _padding = @tokenizer.padding
+       # Set truncation and padding on the backend tokenizer
+       if truncation_strategy == TruncationStrategy::DO_NOT_TRUNCATE
+         if !_truncation.nil?
+           @tokenizer.no_truncation
+         end
+       else
+         target = {
+           max_length: max_length,
+           stride: stride,
+           strategy: truncation_strategy,
+           direction: @truncation_side
+         }
+
+         # _truncation might contain more keys than the target `transformers`
+         # supports. Use only the target keys to trigger `enable_truncation`.
+         # This should enable this code to work with various `tokenizers`
+         # targets.
+         if _truncation.nil?
+           current = nil
+         else
+           current = target.to_h { |k, _| [k, _truncation[k]] }
+         end
+
+         if current != target
+           @tokenizer.enable_truncation(target.delete(:max_length), **target)
+         end
+       end
+
+       if padding_strategy == PaddingStrategy::DO_NOT_PAD
+         if !_padding.nil?
+           @tokenizer.no_padding
+         end
+       else
+         length = padding_strategy == PaddingStrategy::MAX_LENGTH ? max_length : nil
+         target = {
+           length: length,
+           direction: @padding_side,
+           pad_id: @pad_token_id,
+           pad_token: @pad_token,
+           pad_type_id: @pad_token_type_id,
+           pad_to_multiple_of: pad_to_multiple_of
+         }
+         if _padding != target
+           @tokenizer.enable_padding(**target)
+         end
+       end
+     end
+
+     def _batch_encode_plus(
+       batch_text_or_text_pairs,
+       add_special_tokens: true,
+       padding_strategy: PaddingStrategy::DO_NOT_PAD,
+       truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+       max_length: nil,
+       stride: 0,
+       is_split_into_words: false,
+       pad_to_multiple_of: nil,
+       return_tensors: nil,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true
+     )
+       if !batch_text_or_text_pairs.is_a?(Array)
+         raise TypeError, "batch_text_or_text_pairs has to be an array (got #{batch_text_or_text_pairs.class.name})"
+       end
+
+       # Set the truncation and padding strategy and restore the initial configuration
+       set_truncation_and_padding(
+         padding_strategy: padding_strategy,
+         truncation_strategy: truncation_strategy,
+         max_length: max_length,
+         stride: stride,
+         pad_to_multiple_of: pad_to_multiple_of
+       )
+
+       encodings =
+         @tokenizer.encode_batch(
+           batch_text_or_text_pairs,
+           add_special_tokens: add_special_tokens,
+           is_pretokenized: is_split_into_words,
+         )
+
+       # Convert encoding to dict
+       # `Tokens` has type: Tuple[
+       #   List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
+       #   List[EncodingFast]
+       # ]
+       # with nested dimensions corresponding to batch, overflows, sequence length
+       tokens_and_encodings =
+         encodings.map do |encoding|
+           _convert_encoding(
+             encoding: encoding,
+             return_token_type_ids: return_token_type_ids,
+             return_attention_mask: return_attention_mask,
+             return_overflowing_tokens: return_overflowing_tokens,
+             return_special_tokens_mask: return_special_tokens_mask,
+             return_offsets_mapping: return_offsets_mapping,
+             return_length: return_length,
+             verbose: verbose
+           )
+         end
+
+       # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+       # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+       # (we say ~ because the number of overflows varies with the example in the batch)
+       #
+       # To match each overflowing sample with the original sample in the batch
+       # we add an overflow_to_sample_mapping array (see below)
+       sanitized_tokens = {}
+       tokens_and_encodings[0][0].each_key do |key|
+         stack = tokens_and_encodings.map { |item, _| item[key][0] }
+         sanitized_tokens[key] = stack
+       end
+       sanitized_encodings = tokens_and_encodings.map { |_, item| item[0] }
+
+       # If returning overflowing tokens, we need to return a mapping
+       # from the batch idx to the original sample
+       if return_overflowing_tokens
+         overflow_to_sample_mapping = []
+         tokens_and_encodings.each_with_index do |(toks, _), i|
+           overflow_to_sample_mapping += [i] * toks["input_ids"].length
+         end
+         sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+       end
+
+       sanitized_tokens["input_ids"].each do |input_ids|
+         _eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+       end
+
+       BatchEncoding.new(data: sanitized_tokens, encoding: sanitized_encodings, tensor_type: return_tensors)
+     end
+
+     def _convert_encoding(
+       encoding:,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true
+     )
+       if return_token_type_ids.nil?
+         return_token_type_ids = self.class.model_input_names.include?("token_type_ids")
+       end
+       if return_attention_mask.nil?
+         return_attention_mask = self.class.model_input_names.include?("attention_mask")
+       end
+
+       if return_overflowing_tokens && !encoding.overflowing.nil?
+         encodings = [encoding] + encoding.overflowing
+       else
+         encodings = [encoding]
+       end
+
+       encoding_dict = Hash.new { |h, k| h[k] = [] }
+       encodings.each do |e|
+         encoding_dict["input_ids"] << e.ids
+
+         if return_token_type_ids
+           encoding_dict["token_type_ids"] << e.type_ids
+         end
+         if return_attention_mask
+           encoding_dict["attention_mask"] << e.attention_mask
+         end
+         if return_special_tokens_mask
+           encoding_dict["special_tokens_mask"] << e.special_tokens_mask
+         end
+         if return_offsets_mapping
+           encoding_dict["offset_mapping"] << e.offsets
+         end
+         if return_length
+           encoding_dict["length"] << e.ids.length
+         end
+       end
+
+       [encoding_dict, encodings]
+     end
+
+     def _encode_plus(
+       text:,
+       text_pair: nil,
+       add_special_tokens: true,
+       padding_strategy: PaddingStrategy::DO_NOT_PAD,
+       truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+       max_length: nil,
+       stride: 0,
+       is_split_into_words: false,
+       pad_to_multiple_of: nil,
+       return_tensors: nil,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true,
+       **kwargs
+     )
+       batched_input = text_pair ? [[text, text_pair]] : [text]
+       batched_output =
+         _batch_encode_plus(
+           batched_input,
+           is_split_into_words: is_split_into_words,
+           add_special_tokens: add_special_tokens,
+           padding_strategy: padding_strategy,
+           truncation_strategy: truncation_strategy,
+           max_length: max_length,
+           stride: stride,
+           pad_to_multiple_of: pad_to_multiple_of,
+           return_tensors: return_tensors,
+           return_token_type_ids: return_token_type_ids,
+           return_attention_mask: return_attention_mask,
+           return_overflowing_tokens: return_overflowing_tokens,
+           return_special_tokens_mask: return_special_tokens_mask,
+           return_offsets_mapping: return_offsets_mapping,
+           return_length: return_length,
+           verbose: verbose,
+           **kwargs
+         )
+
+       # If return_tensors is nil, we can remove the leading batch axis.
+       # Overflowing tokens are returned as a batch of output, so we keep them in that case.
+       if return_tensors.nil? && !return_overflowing_tokens
+         batched_output =
+           BatchEncoding.new(
+             data: batched_output.items.to_h { |key, value|
+               [key, value.length > 0 && value[0].is_a?(Array) ? value[0] : value]
+             },
+             encoding: batched_output.encodings
+           )
+       end
+
+       _eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+       batched_output
+     end
+   end
+ end
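
The class above mirrors the Python PreTrainedTokenizerFast API on top of the tokenizers gem. For orientation, here is a minimal usage sketch; it is not part of the released files, and the AutoTokenizer loading call and model id are assumptions based on the Hugging Face conventions this port follows (models/auto/tokenization_auto.rb is included in the file list above).

    require "transformers-rb"

    # Assumed entry point mirroring the Python AutoTokenizer API; the model id is illustrative.
    tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")

    enc = tokenizer.("Ruby port of transformers") # assumed callable interface from the base class
    enc["input_ids"]    # single example, so _encode_plus strips the leading batch axis
    tokenizer.is_fast   # => true for PreTrainedTokenizerFast subclasses
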
data/lib/transformers/torch_utils.rb
@@ -0,0 +1,25 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module TorchUtils
+     def self.apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors)
+       if chunk_size > 0
+         raise Todo
+       end
+
+       forward_fn.(*input_tensors)
+     end
+   end
+ end
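
A small sketch of the helper above: in 0.1.0 only a chunk_size of 0 is supported (a positive chunk_size raises Todo), so the call simply applies the forward proc to the inputs. The proc and tensor here are illustrative.

    require "torch" # torch-rb, which the gem builds on

    forward = ->(x) { x * 2 }   # any callable standing in for a layer's forward pass
    input = Torch.ones(2, 3)
    Transformers::TorchUtils.apply_chunking_to_forward(forward, 0, 1, input)
    # => 2x3 tensor of 2.0; passing chunk_size > 0 would raise Todo in this release
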
data/lib/transformers/utils/_init.rb
@@ -0,0 +1,31 @@
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   WEIGHTS_NAME = "pytorch_model.bin"
+   WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
+   TF2_WEIGHTS_NAME = "tf_model.h5"
+   TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
+   TF_WEIGHTS_NAME = "model.ckpt"
+   FLAX_WEIGHTS_NAME = "flax_model.msgpack"
+   FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
+   SAFE_WEIGHTS_NAME = "model.safetensors"
+   SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
+   CONFIG_NAME = "config.json"
+   FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
+   IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
+   PROCESSOR_NAME = "processor_config.json"
+   GENERATION_CONFIG_NAME = "generation_config.json"
+   MODEL_CARD_NAME = "modelcard.json"
+ end
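
These constants are the canonical file names a model repository is expected to contain. A hypothetical sketch of how they are typically combined with a resolved snapshot directory (the directory path below is made up):

    snapshot_dir = "/tmp/hf-cache/models--bert-base-uncased/snapshots/abc123" # hypothetical

    # Prefer safetensors weights and fall back to the PyTorch pickle format.
    weights_path = [Transformers::SAFE_WEIGHTS_NAME, Transformers::WEIGHTS_NAME]
      .map { |name| File.join(snapshot_dir, name) }
      .find { |path| File.exist?(path) }

    config_path = File.join(snapshot_dir, Transformers::CONFIG_NAME)
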
data/lib/transformers/utils/generic.rb
@@ -0,0 +1,107 @@
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class ModelOutput
+     def self.attributes
+       @attributes ||= []
+     end
+
+     def self.attribute(attribute)
+       attributes << attribute.to_sym
+
+       define_method(attribute) do
+         self[attribute]
+       end
+     end
+
+     def initialize(**kwargs)
+       @data = kwargs
+     end
+
+     def [](k)
+       if k.is_a?(String) || k.is_a?(Symbol)
+         @data[k.to_sym]
+       else
+         to_tuple[k]
+       end
+     end
+
+     def to_tuple
+       self.class.attributes.map { |k| @data[k] }.compact
+     end
+   end
+
+   class ExplicitEnum
+     def initialize(value)
+       expected = self.class.constants.map { |k| self.class.const_get(k) }
+       unless expected.include?(value)
+         raise ArgumentError, "#{value} is not a valid #{self.class.name}, please select one of #{expected.inspect}"
+       end
+       @value = value
+     end
+
+     def to_s
+       @value
+     end
+   end
+
+   class PaddingStrategy < ExplicitEnum
+     LONGEST = "longest"
+     MAX_LENGTH = "max_length"
+     DO_NOT_PAD = "do_not_pad"
+   end
+
+   class TensorType < ExplicitEnum
+     PYTORCH = "pt"
+     TENSORFLOW = "tf"
+     NUMPY = "np"
+     JAX = "jax"
+     MLX = "mlx"
+   end
+
+   module Utils
+     def self.infer_framework(model_class)
+       if model_class < Torch::NN::Module
+         "pt"
+       else
+         raise TypeError, "Could not infer framework from class #{model_class}."
+       end
+     end
+
+     def self._is_numo(x)
+       x.is_a?(Numo::NArray)
+     end
+
+     def self.is_numo_array(x)
+       _is_numo(x)
+     end
+
+     def self._is_torch(x)
+       x.is_a?(Torch::Tensor)
+     end
+
+     def self.is_torch_tensor(x)
+       _is_torch(x)
+     end
+
+     def self._is_torch_device(x)
+       x.is_a?(Torch::Device)
+     end
+
+     def self.is_torch_device(x)
+       _is_torch_device(x)
+     end
+   end
+ end
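
A self-contained sketch exercising ModelOutput and ExplicitEnum exactly as defined above (SampleOutput is an illustrative subclass, not part of the gem):

    class SampleOutput < Transformers::ModelOutput
      attribute :logits
      attribute :hidden_states
    end

    out = SampleOutput.new(logits: [0.1, 0.9])
    out.logits    # => [0.1, 0.9] (reader generated by `attribute`)
    out[:logits]  # string or symbol keys read @data
    out[0]        # integer keys index to_tuple, which drops nil attributes

    Transformers::PaddingStrategy.new("longest").to_s # => "longest"
    # Transformers::PaddingStrategy.new("left") would raise ArgumentError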