transformers-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (65) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
@@ -0,0 +1,165 @@
1
+ # Copyright 2021 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
# Names the two supported positions of the channel axis in an image array.
# The values are frozen so these shared constants cannot be mutated in place.
class ChannelDimension < ExplicitEnum
  FIRST = "channels_first".freeze
  LAST = "channels_last".freeze
end
20
+
21
+ module ImageUtils
22
# Loads +image+ and returns it as a Vips::Image.
#
# image   - a URI (read through open-uri), a String path to an existing
#           file, or a Vips::Image (returned as-is).
# timeout - passed as both +open_timeout+ and +read_timeout+ when reading
#           from a URI; ignored for the other input kinds.
#
# Raises ArgumentError when +image+ is none of the accepted kinds, including
# a String that does not name an existing file.
def self.load_image(image, timeout: nil)
  Utils.requires_backends(__method__, ["vision"])

  if image.is_a?(URI)
    # Required lazily so open-uri is only loaded when a URI is actually read.
    require "open-uri"

    Vips::Image.new_from_buffer(image.read(open_timeout: timeout, read_timeout: timeout), "")
  elsif image.is_a?(String) && File.exist?(image)
    Vips::Image.new_from_file(image)
  elsif image.is_a?(Vips::Image)
    image
  else
    raise ArgumentError, "Incorrect format used for image"
  end
end
37
+
38
# Verifies that each enabled preprocessing step comes with the options it needs.
#
# Checks run in a fixed order (rescale, pad, normalize, center crop, resize)
# and the first incomplete step raises ArgumentError. Returns nil when every
# enabled step is fully specified.
def self.validate_preprocess_arguments(
  do_rescale: nil,
  rescale_factor: nil,
  do_normalize: nil,
  image_mean: nil,
  image_std: nil,
  do_pad: nil,
  size_divisibility: nil,
  do_center_crop: nil,
  crop_size: nil,
  do_resize: nil,
  size: nil,
  resample: nil
)
  raise ArgumentError, "`rescale_factor` must be specified if `do_rescale` is `true`." if do_rescale && rescale_factor.nil?

  # Here, size_divisor might be passed as the value of size
  raise ArgumentError, "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `true`." if do_pad && size_divisibility.nil?

  raise ArgumentError, "`image_mean` and `image_std` must both be specified if `do_normalize` is `true`." if do_normalize && [image_mean, image_std].any?(&:nil?)

  raise ArgumentError, "`crop_size` must be specified if `do_center_crop` is `true`." if do_center_crop && crop_size.nil?

  raise ArgumentError, "`size` and `resample` must be specified if `do_resize` is `true`." if do_resize && [size, resample].any?(&:nil?)
end
73
+
74
# Wraps a single image in an Array; an Array argument is returned as-is.
# `expected_ndims` is accepted but not used by this implementation.
def self.make_list_of_images(images, expected_ndims: 3)
  # TODO improve
  return images if images.is_a?(Array)

  [images]
end
78
+
79
# Builds a Numo::UInt8 array of shape [height, width, bands] from the bytes
# returned by `img.write_to_memory`.
# NOTE(review): presumably requires `img` to hold 8-bit samples so the byte
# count matches the shape; confirm against callers.
def self.to_numo_array(img)
  raw_bytes = img.write_to_memory
  dims = [img.height, img.width, img.bands]
  Numo::UInt8.from_binary(raw_bytes, dims)
end
82
+
83
# Guesses whether the channel axis of `image` comes first or last.
#
# image        - object responding to `ndim` and `shape`; only 3- and
#                4-dimensional inputs are supported.
# num_channels - an Integer or a list of acceptable channel counts;
#                defaults to [1, 3] when nil.
#
# Returns ChannelDimension::FIRST or ChannelDimension::LAST. When both
# candidate axes match, a warning is logged and FIRST is returned.
# Raises ArgumentError for an unsupported ndim or when neither axis matches.
def self.infer_channel_dimension_format(
  image, num_channels: nil
)
  allowed = num_channels.nil? ? [1, 3] : num_channels
  allowed = [allowed] if allowed.is_a?(Integer)

  first_dim, last_dim =
    case image.ndim
    when 3 then [0, 2]
    when 4 then [1, 3]
    else raise ArgumentError, "Unsupported number of image dimensions: #{image.ndim}"
    end

  first_matches = allowed.include?(image.shape[first_dim])
  last_matches = allowed.include?(image.shape[last_dim])

  if first_matches
    if last_matches
      Transformers.logger.warn(
        "The channel dimension is ambiguous. Got image shape #{image.shape}. Assuming channels are the first dimension."
      )
    end
    return ChannelDimension::FIRST
  end
  return ChannelDimension::LAST if last_matches

  raise ArgumentError, "Unable to infer channel dimension format"
end
109
+
110
# Returns the index of the channel axis of `image`.
#
# When `input_data_format` is nil it is inferred with
# `infer_channel_dimension_format`. ChannelDimension::FIRST maps to
# `image.ndim - 3`, ChannelDimension::LAST to `image.ndim - 1`; any other
# format raises ArgumentError.
def self.get_channel_dimension_axis(
  image, input_data_format: nil
)
  data_format = input_data_format.nil? ? infer_channel_dimension_format(image) : input_data_format

  return image.ndim - 3 if data_format == ChannelDimension::FIRST
  return image.ndim - 1 if data_format == ChannelDimension::LAST

  raise ArgumentError, "Unsupported data format: #{data_format}"
end
123
+
124
# Truthy only when `Utils.is_vision_available` is truthy and `img` is a
# Vips::Image. The class is not referenced unless that availability check
# passes.
def self.is_vips_image(img)
  vision_ready = Utils.is_vision_available
  vision_ready && img.is_a?(Vips::Image)
end
127
+
128
# Reports whether `img` is one of the accepted image kinds, trying the Vips
# check first, then the Numo check, then the Torch check, and stopping at
# the first truthy answer.
# NOTE(review): `is_numo_array` and `is_torch_tensor` are not defined in the
# visible part of this module; confirm they resolve on this receiver.
def self.is_valid_image(img)
  vips_hit = is_vips_image(img)
  return vips_hit if vips_hit

  numo_hit = is_numo_array(img)
  return numo_hit if numo_hit

  is_torch_tensor(img)
end
131
+
132
# Returns true when `imgs` is a valid image, or an arbitrarily nested Array
# whose leaves are all valid images; false otherwise. An empty Array is
# considered valid.
def self.valid_images(imgs)
  # A list is valid only if each entry (itself possibly a list) is valid.
  return imgs.all? { |entry| valid_images(entry) } if imgs.is_a?(Array)

  # Otherwise we were handed a single image or a batched tensor of images.
  is_valid_image(imgs) ? true : false
end
146
+
147
# Reports whether the pixel values of `image` already lie within [0, 1].
# A Numo::UInt8 array is never considered scaled and its values are not
# inspected.
def self.is_scaled_image(image)
  # A floating-type image may still hold values in [0, 255], so the actual
  # range is checked rather than the type alone.
  !image.is_a?(Numo::UInt8) && image.min >= 0 && image.max <= 1
end
155
+
156
# Logs a warning listing every entry of `captured_kwargs` that does not
# appear in `valid_processor_keys`. Nothing is logged (and nil is returned)
# when there are no such entries.
def self.validate_kwargs(valid_processor_keys:, captured_kwargs:)
  captured = Set.new(captured_kwargs)
  recognized = Set.new(valid_processor_keys)
  leftover = captured - recognized
  return unless leftover.any?

  # TODO raise a warning here instead of simply logging?
  Transformers.logger.warn("Unused or unrecognized kwargs: #{leftover.join(", ")}.")
end
164
+ end
165
+ end
@@ -0,0 +1,81 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
# Output record declaring, in this order, the attributes
# `last_hidden_state`, `hidden_states` and `attentions` through the
# inherited `attribute` macro.
class BaseModelOutput < ModelOutput
  %i[last_hidden_state hidden_states attentions].each { |name| attribute name }
end
21
+
22
# Output record declaring, in this order, the attributes
# `last_hidden_state`, `pooler_output`, `hidden_states` and `attentions`
# through the inherited `attribute` macro.
class BaseModelOutputWithPooling < ModelOutput
  %i[last_hidden_state pooler_output hidden_states attentions].each { |name| attribute name }
end
28
+
29
# Output record declaring, in this order, the attributes
# `last_hidden_state`, `pooler_output`, `hidden_states`, `past_key_values`,
# `attentions` and `cross_attentions` through the inherited `attribute` macro.
class BaseModelOutputWithPoolingAndCrossAttentions < ModelOutput
  %i[
    last_hidden_state
    pooler_output
    hidden_states
    past_key_values
    attentions
    cross_attentions
  ].each { |name| attribute name }
end
37
+
38
# Output record declaring, in this order, the attributes
# `last_hidden_state`, `past_key_values`, `hidden_states`, `attentions` and
# `cross_attentions` through the inherited `attribute` macro.
class BaseModelOutputWithPastAndCrossAttentions < ModelOutput
  %i[
    last_hidden_state
    past_key_values
    hidden_states
    attentions
    cross_attentions
  ].each { |name| attribute name }
end
45
+
46
# Output record declaring, in this order, the attributes `loss`, `logits`,
# `hidden_states` and `attentions` through the inherited `attribute` macro.
class MaskedLMOutput < ModelOutput
  %i[loss logits hidden_states attentions].each { |name| attribute name }
end
52
+
53
# Output record declaring, in this order, the attributes `loss`, `logits`,
# `hidden_states` and `attentions` through the inherited `attribute` macro.
class SequenceClassifierOutput < ModelOutput
  %i[loss logits hidden_states attentions].each { |name| attribute name }
end
59
+
60
# Output record declaring, in this order, the attributes `loss`, `logits`,
# `hidden_states` and `attentions` through the inherited `attribute` macro.
class TokenClassifierOutput < ModelOutput
  %i[loss logits hidden_states attentions].each { |name| attribute name }
end
66
+
67
# Output record declaring, in this order, the attributes `loss`,
# `start_logits`, `end_logits`, `hidden_states` and `attentions` through the
# inherited `attribute` macro.
class QuestionAnsweringModelOutput < ModelOutput
  %i[loss start_logits end_logits hidden_states attentions].each { |name| attribute name }
end
74
+
75
# Output record declaring, in this order, the attributes `loss`, `logits`,
# `hidden_states` and `attentions` through the inherited `attribute` macro.
class ImageClassifierOutput < ModelOutput
  %i[loss logits hidden_states attentions].each { |name| attribute name }
end
81
+ end