transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
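
The layout mirrors the upstream Python package: model implementations under data/lib/transformers/models, pipelines under data/lib/transformers/pipelines, and a small vendored Hugging Face Hub client under data/lib/transformers/hf_hub, with lib/transformers-rb.rb as a one-line entry point. For orientation, a minimal setup sketch follows (based on the gem's README; the example input and the shape of the returned value are illustrative):

# Gemfile
gem "transformers-rb"

# application code
require "transformers-rb"

# High-level pipeline API from the README; model weights are fetched from the
# Hugging Face Hub on first use (see data/lib/transformers/hf_hub/file_download.rb).
classifier = Transformers.pipeline("sentiment-analysis")
classifier.("Ruby bindings for Transformers are super cool")
# => e.g. {label: "POSITIVE", score: 0.99...} (illustrative output)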
data/lib/transformers/image_utils.rb
@@ -0,0 +1,165 @@
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class ChannelDimension < ExplicitEnum
+    FIRST = "channels_first"
+    LAST = "channels_last"
+  end
+
+  module ImageUtils
+    def self.load_image(image, timeout: nil)
+      Utils.requires_backends(__method__, ["vision"])
+      if image.is_a?(URI)
+        require "open-uri"
+
+        image = Vips::Image.new_from_buffer(image.read(open_timeout: timeout, read_timeout: timeout), "")
+      elsif image.is_a?(String) && File.exist?(image)
+        image = Vips::Image.new_from_file(image)
+      elsif image.is_a?(Vips::Image)
+        image = image
+      else
+        raise ArgumentError, "Incorrect format used for image"
+      end
+      image
+    end
+
+    def self.validate_preprocess_arguments(
+      do_rescale: nil,
+      rescale_factor: nil,
+      do_normalize: nil,
+      image_mean: nil,
+      image_std: nil,
+      do_pad: nil,
+      size_divisibility: nil,
+      do_center_crop: nil,
+      crop_size: nil,
+      do_resize: nil,
+      size: nil,
+      resample: nil
+    )
+      if do_rescale && rescale_factor.nil?
+        raise ArgumentError, "`rescale_factor` must be specified if `do_rescale` is `true`."
+      end
+
+      if do_pad && size_divisibility.nil?
+        # Here, size_divisor might be passed as the value of size
+        raise ArgumentError, "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `true`."
+      end
+
+      if do_normalize && (image_mean.nil? || image_std.nil?)
+        raise ArgumentError, "`image_mean` and `image_std` must both be specified if `do_normalize` is `true`."
+      end
+
+      if do_center_crop && crop_size.nil?
+        raise ArgumentError, "`crop_size` must be specified if `do_center_crop` is `true`."
+      end
+
+      if do_resize && (size.nil? || resample.nil?)
+        raise ArgumentError, "`size` and `resample` must be specified if `do_resize` is `true`."
+      end
+    end
+
+    def self.make_list_of_images(images, expected_ndims: 3)
+      # TODO improve
+      images.is_a?(Array) ? images : [images]
+    end
+
+    def self.to_numo_array(img)
+      Numo::UInt8.from_binary(img.write_to_memory, [img.height, img.width, img.bands])
+    end
+
+    def self.infer_channel_dimension_format(
+      image, num_channels: nil
+    )
+      num_channels = !num_channels.nil? ? num_channels : [1, 3]
+      num_channels = num_channels.is_a?(Integer) ? [num_channels] : num_channels
+
+      if image.ndim == 3
+        first_dim, last_dim = 0, 2
+      elsif image.ndim == 4
+        first_dim, last_dim = 1, 3
+      else
+        raise ArgumentError, "Unsupported number of image dimensions: #{image.ndim}"
+      end
+
+      if num_channels.include?(image.shape[first_dim]) && num_channels.include?(image.shape[last_dim])
+        Transformers.logger.warn(
+          "The channel dimension is ambiguous. Got image shape #{image.shape}. Assuming channels are the first dimension."
+        )
+        return ChannelDimension::FIRST
+      elsif num_channels.include?(image.shape[first_dim])
+        return ChannelDimension::FIRST
+      elsif num_channels.include?(image.shape[last_dim])
+        return ChannelDimension::LAST
+      end
+      raise ArgumentError, "Unable to infer channel dimension format"
+    end
+
+    def self.get_channel_dimension_axis(
+      image, input_data_format: nil
+    )
+      if input_data_format.nil?
+        input_data_format = infer_channel_dimension_format(image)
+      end
+      if input_data_format == ChannelDimension::FIRST
+        return image.ndim - 3
+      elsif input_data_format == ChannelDimension::LAST
+        return image.ndim - 1
+      end
+      raise ArgumentError, "Unsupported data format: #{input_data_format}"
+    end
+
+    def self.is_vips_image(img)
+      Utils.is_vision_available && img.is_a?(Vips::Image)
+    end
+
+    def self.is_valid_image(img)
+      is_vips_image(img) || is_numo_array(img) || is_torch_tensor(img)
+    end
+
+    def self.valid_images(imgs)
+      # If we have an list of images, make sure every image is valid
+      if imgs.is_a?(Array)
+        imgs.each do |img|
+          if !valid_images(img)
+            return false
+          end
+        end
+      # If not a list of tuple, we have been given a single image or batched tensor of images
+      elsif !is_valid_image(imgs)
+        return false
+      end
+      true
+    end
+
+    def self.is_scaled_image(image)
+      if image.is_a?(Numo::UInt8)
+        return false
+      end
+
+      # It's possible the image has pixel values in [0, 255] but is of floating type
+      image.min >= 0 && image.max <= 1
+    end
+
+    def self.validate_kwargs(valid_processor_keys:, captured_kwargs:)
+      unused_keys = Set.new(captured_kwargs).difference(Set.new(valid_processor_keys))
+      if unused_keys.any?
+        unused_key_str = unused_keys.join(", ")
+        # TODO raise a warning here instead of simply logging?
+        Transformers.logger.warn("Unused or unrecognized kwargs: #{unused_key_str}.")
+      end
+    end
+  end
+end
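
To make the channel-inference helpers above concrete, here is a small usage sketch. It assumes the numo-narray and ruby-vips backends are installed and "cat.jpg" is a placeholder local path; the return values in the comments are what the code above produces for this input:

require "numo/narray"

# A 224x224 RGB image in HWC ("channels_last") layout.
img = Numo::UInt8.zeros(224, 224, 3)

Transformers::ImageUtils.infer_channel_dimension_format(img)
# => "channels_last" (ChannelDimension::LAST)

Transformers::ImageUtils.get_channel_dimension_axis(img)
# => 2 (image.ndim - 1, because the inferred format is channels_last)

# load_image accepts a local file path, a URI, or an existing Vips::Image.
vips_img = Transformers::ImageUtils.load_image("cat.jpg")  # placeholder path
pixels = Transformers::ImageUtils.to_numo_array(vips_img)  # Numo::UInt8, shape [height, width, bands]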
data/lib/transformers/modeling_outputs.rb
@@ -0,0 +1,81 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Transformers
+  class BaseModelOutput < ModelOutput
+    attribute :last_hidden_state
+    attribute :hidden_states
+    attribute :attentions
+  end
+
+  class BaseModelOutputWithPooling < ModelOutput
+    attribute :last_hidden_state
+    attribute :pooler_output
+    attribute :hidden_states
+    attribute :attentions
+  end
+
+  class BaseModelOutputWithPoolingAndCrossAttentions < ModelOutput
+    attribute :last_hidden_state
+    attribute :pooler_output
+    attribute :hidden_states
+    attribute :past_key_values
+    attribute :attentions
+    attribute :cross_attentions
+  end
+
+  class BaseModelOutputWithPastAndCrossAttentions < ModelOutput
+    attribute :last_hidden_state
+    attribute :past_key_values
+    attribute :hidden_states
+    attribute :attentions
+    attribute :cross_attentions
+  end
+
+  class MaskedLMOutput < ModelOutput
+    attribute :loss
+    attribute :logits
+    attribute :hidden_states
+    attribute :attentions
+  end
+
+  class SequenceClassifierOutput < ModelOutput
+    attribute :loss
+    attribute :logits
+    attribute :hidden_states
+    attribute :attentions
+  end
+
+  class TokenClassifierOutput < ModelOutput
+    attribute :loss
+    attribute :logits
+    attribute :hidden_states
+    attribute :attentions
+  end
+
+  class QuestionAnsweringModelOutput < ModelOutput
+    attribute :loss
+    attribute :start_logits
+    attribute :end_logits
+    attribute :hidden_states
+    attribute :attentions
+  end
+
+  class ImageClassifierOutput < ModelOutput
+    attribute :loss
+    attribute :logits
+    attribute :hidden_states
+    attribute :attentions
+  end
+end
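
Each class above is a thin, named record for a model's forward-pass results; the attribute class method comes from ModelOutput in data/lib/transformers/utils/generic.rb, which is not shown in this excerpt. Purely as an illustration of the pattern (not the gem's actual implementation), an attribute DSL of this shape can be written as:

# Illustrative sketch of an attribute-style output record, NOT the gem's ModelOutput.
class ExampleOutput
  # Declare a named field backed by a kwargs hash.
  def self.attribute(name)
    define_method(name) { @data[name] }                      # reader, e.g. output.logits
    define_method("#{name}=") { |value| @data[name] = value } # writer
  end

  def initialize(**kwargs)
    @data = kwargs
  end

  def [](key)
    @data[key]
  end

  attribute :loss
  attribute :logits
end

out = ExampleOutput.new(logits: [[-1.2, 3.4]])
out.logits  # => [[-1.2, 3.4]]
out[:loss]  # => nil

The pipelines in this release consume these records by reading the relevant field (for example, a text-classification pipeline postprocesses the logits of a SequenceClassifierOutput).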