transformers-rb 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,165 @@
|
|
1
|
+
# Copyright 2021 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
# Names the two supported positions of the channel axis in an image array.
# infer_channel_dimension_format returns FIRST when the channel axis precedes
# the spatial axes and LAST when it follows them.
# NOTE(review): ExplicitEnum is defined outside this file and may enumerate
# these constants in declaration order — keep FIRST before LAST.
class ChannelDimension < ExplicitEnum
  FIRST = "channels_first"
  LAST = "channels_last"
end
|
20
|
+
|
21
|
+
module ImageUtils
|
22
|
+
# Loads an image into a Vips::Image.
#
# @param image [URI, String, Vips::Image] a URI that open-uri can read
#   (e.g. http/https/ftp), a path to an existing file, or an already-loaded image
# @param timeout [Numeric, nil] forwarded to open-uri as both `open_timeout`
#   and `read_timeout`; only used for URI input
# @return [Vips::Image] the decoded image; a Vips::Image argument is returned as-is
# @raise [ArgumentError] when the input is none of the supported kinds, including
#   URIs that open-uri cannot read (e.g. mailto:)
def self.load_image(image, timeout: nil)
  Utils.requires_backends(__method__, ["vision"])

  if image.is_a?(URI)
    # Required lazily so open-uri is only loaded when a remote image is requested;
    # it is what adds `#read` to HTTP(S)/FTP URI objects.
    require "open-uri"
    # Other schemes have no `#read`; report them like any other unsupported input
    # instead of letting a NoMethodError escape.
    raise ArgumentError, "Incorrect format used for image" unless image.respond_to?(:read)

    Vips::Image.new_from_buffer(image.read(open_timeout: timeout, read_timeout: timeout), "")
  elsif image.is_a?(String) && File.exist?(image)
    Vips::Image.new_from_file(image)
  elsif image.is_a?(Vips::Image)
    image
  else
    raise ArgumentError, "Incorrect format used for image"
  end
end
|
37
|
+
|
38
|
+
# Checks that every enabled preprocessing step comes with the settings it needs.
#
# Each `do_*` flag is treated as enabled when truthy. Checks are evaluated in a
# fixed order (rescale, pad, normalize, center crop, resize) and the first
# violated requirement is reported.
#
# @return [nil] when all enabled steps are fully configured
# @raise [ArgumentError] describing the first missing setting
def self.validate_preprocess_arguments(
  do_rescale: nil,
  rescale_factor: nil,
  do_normalize: nil,
  image_mean: nil,
  image_std: nil,
  do_pad: nil,
  size_divisibility: nil,
  do_center_crop: nil,
  crop_size: nil,
  do_resize: nil,
  size: nil,
  resample: nil
)
  # Pairs of [violated?, message]; the array order decides which error wins.
  # The padding message is deliberately broad: depending on the model, the
  # size divisor may arrive under several names (even as `size`).
  requirements = [
    [do_rescale && rescale_factor.nil?,
     "`rescale_factor` must be specified if `do_rescale` is `true`."],
    [do_pad && size_divisibility.nil?,
     "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `true`."],
    [do_normalize && (image_mean.nil? || image_std.nil?),
     "`image_mean` and `image_std` must both be specified if `do_normalize` is `true`."],
    [do_center_crop && crop_size.nil?,
     "`crop_size` must be specified if `do_center_crop` is `true`."],
    [do_resize && (size.nil? || resample.nil?),
     "`size` and `resample` must be specified if `do_resize` is `true`."]
  ]

  violation = requirements.find { |violated, _message| violated }
  raise ArgumentError, violation.last if violation
end
|
73
|
+
|
74
|
+
# Normalizes the input to an Array of images.
#
# An Array is returned unchanged (the same object); anything else is wrapped
# in a one-element Array.
# NOTE(review): `expected_ndims` is accepted for signature compatibility but is
# currently unused, so a batched array/tensor is NOT split into separate images.
#
# @return [Array]
def self.make_list_of_images(images, expected_ndims: 3)
  # TODO improve
  return images if images.is_a?(Array)

  [images]
end
|
78
|
+
|
79
|
+
# Copies an image's raw pixel memory into a Numo::UInt8 array shaped
# [height, width, bands].
# NOTE(review): the bytes are reinterpreted as 8-bit values, so this presumably
# expects an 8-bit (uchar) Vips image — confirm callers never pass other formats.
def self.to_numo_array(img)
  pixels = img.write_to_memory
  Numo::UInt8.from_binary(pixels, [img.height, img.width, img.bands])
end
|
82
|
+
|
83
|
+
# Guesses whether an image array stores its channels first or last.
#
# @param image object responding to `ndim` and `shape` (3-D, or 4-D where the
#   leading axis is skipped — presumably a batch axis)
# @param num_channels [Integer, #include?, nil] channel counts that identify a
#   channel axis; defaults to [1, 3]
# @return [String] ChannelDimension::FIRST or ChannelDimension::LAST; when both
#   ends match, a warning is logged and FIRST is returned
# @raise [ArgumentError] for unsupported ndim or when neither end matches
def self.infer_channel_dimension_format(
  image, num_channels: nil
)
  candidates = num_channels.nil? ? [1, 3] : num_channels
  candidates = [candidates] if candidates.is_a?(Integer)

  first_axis, last_axis =
    case image.ndim
    when 3 then [0, 2]
    when 4 then [1, 3]
    else raise ArgumentError, "Unsupported number of image dimensions: #{image.ndim}"
    end

  channels_first = candidates.include?(image.shape[first_axis])
  channels_last = candidates.include?(image.shape[last_axis])

  if channels_first
    if channels_last
      # Both ends look like a channel axis (e.g. shape [3, 224, 3]); prefer first.
      Transformers.logger.warn(
        "The channel dimension is ambiguous. Got image shape #{image.shape}. Assuming channels are the first dimension."
      )
    end
    return ChannelDimension::FIRST
  end
  return ChannelDimension::LAST if channels_last

  raise ArgumentError, "Unable to infer channel dimension format"
end
|
109
|
+
|
110
|
+
# Returns the index of the channel axis of `image`.
#
# @param input_data_format [String, nil] ChannelDimension::FIRST or
#   ChannelDimension::LAST; inferred with infer_channel_dimension_format when nil
# @return [Integer] `ndim - 3` for channels-first, `ndim - 1` for channels-last
# @raise [ArgumentError] for any other data format
def self.get_channel_dimension_axis(
  image, input_data_format: nil
)
  data_format = input_data_format
  data_format = infer_channel_dimension_format(image) if data_format.nil?

  # Distance of the channel axis from the end of the shape.
  offset =
    if data_format == ChannelDimension::FIRST
      3
    elsif data_format == ChannelDimension::LAST
      1
    end
  raise ArgumentError, "Unsupported data format: #{data_format}" if offset.nil?

  image.ndim - offset
end
|
123
|
+
|
124
|
+
# Reports whether `img` is a Vips::Image.
# If Utils.is_vision_available is falsy, that falsy value is returned unchanged
# and Vips::Image is never referenced.
def self.is_vips_image(img)
  vision_ready = Utils.is_vision_available
  return vision_ready unless vision_ready

  img.is_a?(Vips::Image)
end
|
127
|
+
|
128
|
+
# Reports whether `img` is a single supported image: a Vips image, a Numo array,
# or a torch tensor. Checks run in that order and stop at the first truthy
# result; if none match, the last check's (falsy) result is returned.
# NOTE(review): is_numo_array / is_torch_tensor are not defined in this section —
# confirm they are reachable from ImageUtils.
def self.is_valid_image(img)
  vips = is_vips_image(img)
  return vips if vips

  numo = is_numo_array(img)
  return numo if numo

  is_torch_tensor(img)
end
|
131
|
+
|
132
|
+
# Returns true when `imgs` is a valid image, or an Array (possibly nested) whose
# every element is valid; an empty Array counts as valid.
def self.valid_images(imgs)
  return imgs.all? { |img| valid_images(img) } if imgs.is_a?(Array)

  # Not a list: a single image or a batched tensor of images.
  is_valid_image(imgs) ? true : false
end
|
146
|
+
|
147
|
+
# Reports whether pixel values already lie in [0, 1].
# Numo::UInt8 arrays always answer false. Any other array is inspected by value
# range, since floating-point data may still hold values in [0, 255].
def self.is_scaled_image(image)
  !image.is_a?(Numo::UInt8) && image.min >= 0 && image.max <= 1
end
|
155
|
+
|
156
|
+
# Logs a warning listing captured keyword arguments the processor does not recognize.
#
# @param valid_processor_keys [Enumerable] keys the processor accepts
# @param captured_kwargs [Enumerable, Hash] the captured keys, or the captured
#   keyword Hash itself (its keys are used)
# @return [nil] when every key is recognized; otherwise the logger's return value
def self.validate_kwargs(valid_processor_keys:, captured_kwargs:)
  # Set.new(hash) would build [key, value] pairs, flagging every key as unused.
  captured_keys = captured_kwargs.respond_to?(:keys) ? captured_kwargs.keys : captured_kwargs
  unused_keys = Set.new(captured_keys).difference(Set.new(valid_processor_keys))
  # `empty?` (not `any?`) so an unused nil/false key is still reported.
  return if unused_keys.empty?

  # Array#join works on all Ruby versions (Set#join needs 3.0+).
  unused_key_str = unused_keys.to_a.join(", ")
  # TODO raise a warning here instead of simply logging?
  Transformers.logger.warn("Unused or unrecognized kwargs: #{unused_key_str}.")
end
|
164
|
+
end
|
165
|
+
end
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
# Output record with last_hidden_state, hidden_states and attentions fields,
# declared through ModelOutput's `attribute` macro (defined outside this file).
# NOTE(review): keep the declaration order — it presumably mirrors the upstream
# Python dataclass field order; confirm whether ModelOutput relies on it.
class BaseModelOutput < ModelOutput
  attribute :last_hidden_state
  attribute :hidden_states
  attribute :attentions
end
|
21
|
+
|
22
|
+
# Output record like BaseModelOutput plus a pooler_output field, declared
# through ModelOutput's `attribute` macro (defined outside this file).
# NOTE(review): keep the declaration order — it presumably mirrors the upstream
# Python dataclass field order; confirm whether ModelOutput relies on it.
class BaseModelOutputWithPooling < ModelOutput
  attribute :last_hidden_state
  attribute :pooler_output
  attribute :hidden_states
  attribute :attentions
end
|
28
|
+
|
29
|
+
# Output record with pooled output plus past_key_values and cross_attentions
# fields, declared through ModelOutput's `attribute` macro (defined outside this file).
# NOTE(review): keep the declaration order — it presumably mirrors the upstream
# Python dataclass field order; confirm whether ModelOutput relies on it.
class BaseModelOutputWithPoolingAndCrossAttentions < ModelOutput
  attribute :last_hidden_state
  attribute :pooler_output
  attribute :hidden_states
  attribute :past_key_values
  attribute :attentions
  attribute :cross_attentions
end
|
37
|
+
|
38
|
+
# Output record with past_key_values and cross_attentions but no pooled output,
# declared through ModelOutput's `attribute` macro (defined outside this file).
# NOTE(review): keep the declaration order — it presumably mirrors the upstream
# Python dataclass field order; confirm whether ModelOutput relies on it.
class BaseModelOutputWithPastAndCrossAttentions < ModelOutput
  attribute :last_hidden_state
  attribute :past_key_values
  attribute :hidden_states
  attribute :attentions
  attribute :cross_attentions
end
|
45
|
+
|
46
|
+
# Output record with loss, logits, hidden_states and attentions fields,
# presumably produced by masked-language-model heads (confirm against callers).
# Fields are declared through ModelOutput's `attribute` macro (defined outside
# this file).
# NOTE(review): keep the declaration order; confirm whether ModelOutput relies on it.
class MaskedLMOutput < ModelOutput
  attribute :loss
  attribute :logits
  attribute :hidden_states
  attribute :attentions
end
|
52
|
+
|
53
|
+
# Output record with loss, logits, hidden_states and attentions fields,
# presumably produced by sequence-classification heads (confirm against callers).
# Fields are declared through ModelOutput's `attribute` macro (defined outside
# this file).
# NOTE(review): keep the declaration order; confirm whether ModelOutput relies on it.
class SequenceClassifierOutput < ModelOutput
  attribute :loss
  attribute :logits
  attribute :hidden_states
  attribute :attentions
end
|
59
|
+
|
60
|
+
# Output record with loss, logits, hidden_states and attentions fields,
# presumably produced by token-classification heads (confirm against callers).
# Fields are declared through ModelOutput's `attribute` macro (defined outside
# this file).
# NOTE(review): keep the declaration order; confirm whether ModelOutput relies on it.
class TokenClassifierOutput < ModelOutput
  attribute :loss
  attribute :logits
  attribute :hidden_states
  attribute :attentions
end
|
66
|
+
|
67
|
+
# Output record with separate start_logits and end_logits fields (plus loss,
# hidden_states and attentions), presumably produced by span-extraction QA heads
# (confirm against callers). Fields are declared through ModelOutput's
# `attribute` macro (defined outside this file).
# NOTE(review): keep the declaration order; confirm whether ModelOutput relies on it.
class QuestionAnsweringModelOutput < ModelOutput
  attribute :loss
  attribute :start_logits
  attribute :end_logits
  attribute :hidden_states
  attribute :attentions
end
|
74
|
+
|
75
|
+
# Output record with loss, logits, hidden_states and attentions fields,
# presumably produced by image-classification heads (confirm against callers).
# Fields are declared through ModelOutput's `attribute` macro (defined outside
# this file).
# NOTE(review): keep the declaration order; confirm whether ModelOutput relies on it.
class ImageClassifierOutput < ModelOutput
  attribute :loss
  attribute :logits
  attribute :hidden_states
  attribute :attentions
end
|
81
|
+
end
|