transformers-rb 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,94 @@
|
|
1
|
+
module Transformers
  module HfHub
    # Base error for failed HTTP calls to the Hugging Face Hub. The response
    # passed to the constructor (if any) is kept so callers can inspect it.
    class HfHubHTTPError < Error
      # @return [Object, nil] the HTTP response given to the constructor
      attr_reader :response

      # @param message [String] human-readable description of the failure
      # @param response [Object, nil] the HTTP response that triggered the error
      def initialize(message, response = nil)
        super(message)
        @response = response
      end
    end

    # Raised for the Hub's "RepoNotFound" error code.
    class RepositoryNotFoundError < HfHubHTTPError; end

    # Raised for the Hub's "GatedRepo" error code.
    class GatedRepoError < RepositoryNotFoundError; end

    # Raised when the Hub reports "Access to this resource is disabled."
    class DisabledRepoError < HfHubHTTPError; end

    # Raised for the Hub's "RevisionNotFound" error code.
    class RevisionNotFoundError < HfHubHTTPError; end

    # Raised for the Hub's "EntryNotFound" error code.
    class EntryNotFoundError < HfHubHTTPError; end

    # Not raised in this file; presumably used by the download code when a file
    # is missing locally — confirm against callers.
    class LocalEntryNotFoundError < EntryNotFoundError; end

    # Raised for HTTP 400 responses.
    class BadRequestError < HfHubHTTPError; end

    class << self
      # Translates an unsuccessful Net::HTTP response into a Hub-specific error.
      #
      # `response.value` raises for non-success responses; redirects are skipped
      # so the caller can follow them. The Hub's `X-Error-Code` /
      # `X-Error-Message` headers, or the status code, choose the error class.
      # Any other failure is wrapped in a generic HfHubHTTPError carrying the
      # original exception's message (the original is also kept as `cause`).
      #
      # @param response [Net::HTTPResponse] the response to check
      # @param endpoint_name [String, nil] endpoint name used in 400 messages
      # @raise [HfHubHTTPError] (or a subclass) for any failing response
      def hf_raise_for_status(response, endpoint_name: nil)
        response.value unless response.is_a?(Net::HTTPRedirection)
      rescue => e
        error_code = response["X-Error-Code"]
        error_message = response["X-Error-Message"]

        if error_code == "RevisionNotFound"
          message = "#{response.code} Client Error." + "\n\n" + "Revision Not Found for url: #{response.uri}."
          raise RevisionNotFoundError.new(message, response)

        elsif error_code == "EntryNotFound"
          message = "#{response.code} Client Error." + "\n\n" + "Entry Not Found for url: #{response.uri}."
          raise EntryNotFoundError.new(message, response)

        elsif error_code == "GatedRepo"
          message = "#{response.code} Client Error." + "\n\n" + "Cannot access gated repo for url #{response.uri}."
          raise GatedRepoError.new(message, response)

        elsif error_message == "Access to this resource is disabled."
          message =
            "#{response.code} Client Error." +
            "\n\n" +
            "Cannot access repository for url #{response.uri}." +
            "\n" +
            "Access to this resource is disabled."
          raise DisabledRepoError.new(message, response)

        elsif error_code == "RepoNotFound"
          # 401 is misleading as it is returned for:
          # - private and gated repos if user is not authenticated
          # - missing repos
          # => for now, we process them as `RepoNotFound` anyway.
          # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
          message =
            "#{response.code} Client Error." +
            "\n\n" +
            "Repository Not Found for url: #{response.uri}." +
            "\nPlease make sure you specified the correct `repo_id` and" +
            " `repo_type`.\nIf you are trying to access a private or gated repo," +
            " make sure you are authenticated."
          raise RepositoryNotFoundError.new(message, response)

        elsif response.code.to_i == 400
          message = endpoint_name.nil? ? "\n\nBad request:" : "\n\nBad request for #{endpoint_name} endpoint:"
          raise BadRequestError.new(message, response)

        elsif response.code.to_i == 403
          # `#{...}` was previously missing here, so the status code never appeared.
          message =
            "\n\n#{response.code} Forbidden: #{error_message}." +
            "\nCannot access content at: #{response.uri}." +
            "\nIf you are trying to create or update content, " +
            "make sure you have a token with the `write` role."
          raise HfHubHTTPError.new(message, response)
        end

        # Convert the Net::HTTP error into a `HfHubHTTPError` to display request
        # information as well. `e` must be bound above; it previously was not,
        # which turned this path into a NameError.
        raise HfHubHTTPError.new(e.to_s, response)
      end
    end
  end
end
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# Copyright 2022-present, the HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
  module HfHub
    class << self
      # Builds the HTTP headers sent with requests to the Hugging Face Hub.
      #
      # @param token [String, Boolean, nil] explicit token; +false+ disables
      #   auth, +true+ requires the cached token, +nil+ uses the implicit default
      # @param is_write_action [Boolean] when true a token is mandatory
      # @param library_name [String, nil] library name reported in the user-agent
      # @param library_version [String, nil] library version reported in the user-agent
      # @param user_agent [Hash, String, nil] extra user-agent components
      # @param headers [Hash, nil] extra headers, merged last (they win on key clashes)
      # @return [Hash{String => String}] the headers to send
      # @raise [ArgumentError] when +is_write_action+ is true and no token is available
      def build_hf_headers(
        token: nil,
        is_write_action: false,
        library_name: nil,
        library_version: nil,
        user_agent: nil,
        headers: nil
      )
        # Get auth token to send
        token_to_send = get_token_to_send(token)
        _validate_token_to_send(token_to_send, is_write_action)

        # Combine headers
        hf_headers = {
          "user-agent" => _http_user_agent(
            library_name: library_name,
            library_version: library_version,
            user_agent: user_agent
          )
        }
        hf_headers["authorization"] = "Bearer #{token_to_send}" unless token_to_send.nil?
        hf_headers.merge!(headers) if headers
        hf_headers
      end

      # Decides which token (if any) to send.
      #
      # @param token [String, Boolean, nil] see #build_hf_headers
      # @return [String, nil] the token to send, or nil for none
      # @raise [LocalTokenNotFoundError] when +token+ is +true+ but no cached token exists
      def get_token_to_send(token)
        # Case token is explicitly provided
        return token if token.is_a?(String)

        # Case token is explicitly forbidden
        return nil if token == false

        # Token is not provided: we would get it from local cache.
        # NOTE(review): the cache lookup is not implemented yet, so this is always nil.
        cached_token = nil # get_token

        # Case token is explicitly required
        if token == true
          if cached_token.nil?
            raise LocalTokenNotFoundError,
              "Token is required (`token: true`), but no token found. You" +
              " need to provide a token or be logged in to Hugging Face with" +
              " `huggingface-cli login` or `huggingface_hub.login`. See" +
              " https://huggingface.co/settings/tokens."
          end
          return cached_token
        end

        # Case implicit use of the token is forbidden by env variable
        return nil if HF_HUB_DISABLE_IMPLICIT_TOKEN

        # Otherwise: we use the cached token as the user has not explicitly forbidden it
        cached_token
      end

      # Ensures a token is present for write actions.
      #
      # @raise [ArgumentError] when +is_write_action+ is true and +token+ is nil
      def _validate_token_to_send(token, is_write_action)
        return unless is_write_action && token.nil?

        raise ArgumentError,
          "Token is required (write-access action) but no token found. You need" +
          " to provide a token or be logged in to Hugging Face with" +
          " `huggingface-cli login` or `huggingface_hub.login`. See" +
          " https://huggingface.co/settings/tokens."
      end

      # Formats the user-agent string.
      #
      # The Ruby version is reported as "major.minor". Extra components in
      # +user_agent+ are appended: a Hash becomes "key/value" pairs, and a
      # String is appended verbatim. (Previously +user_agent+ was ignored.)
      #
      # @return [String] e.g. "transformers/1.0; ruby/3.3; file_type/image processor"
      def _http_user_agent(
        library_name: nil,
        library_version: nil,
        user_agent: nil
      )
        ua = library_name.nil? ? "unknown/None" : "#{library_name}/#{library_version}"
        # `RUBY_VERSION.to_f` would turn "3.10.x" into 3.1, so keep the string form.
        ua += "; ruby/#{RUBY_VERSION.split(".").first(2).join(".")}"

        case user_agent
        when Hash
          ua += "; " + user_agent.map { |key, value| "#{key}/#{value}" }.join("; ") unless user_agent.empty?
        when String
          ua += "; " + user_agent unless user_agent.empty?
        end
        ua
      end
    end
  end
end
|
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright 2020 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
  # Loading logic shared by image processors. It locates the processor config
  # (a local JSON file, a local directory, or a Hub repo id), parses it, and
  # builds an instance of the calling class from it.
  class ImageProcessingMixin
    # Loads an image processor from a pretrained config.
    #
    # @param pretrained_model_name_or_path [String, #to_s] config file path,
    #   local directory, or Hub repo id
    # @param cache_dir [String, nil] forwarded to the download helper
    # @param force_download [Boolean] forwarded to the download helper
    # @param local_files_only [Boolean] forwarded to the download helper
    # @param token [String, Boolean, nil] auth token; only forwarded when non-nil
    # @param revision [String] repo revision to load
    # @param kwargs [Hash] other options; those matching an instance variable of
    #   the built processor override the config value (see .from_dict)
    # @return [ImageProcessingMixin, Array] the processor, or
    #   [processor, unused_kwargs] when +return_unused_kwargs: true+ is passed
    def self.from_pretrained(
      pretrained_model_name_or_path,
      cache_dir: nil,
      force_download: false,
      local_files_only: false,
      token: nil,
      revision: "main",
      **kwargs
    )
      kwargs[:cache_dir] = cache_dir
      kwargs[:force_download] = force_download
      kwargs[:local_files_only] = local_files_only
      kwargs[:revision] = revision
      kwargs[:token] = token unless token.nil?

      image_processor_dict, kwargs = get_image_processor_dict(pretrained_model_name_or_path, **kwargs)

      from_dict(image_processor_dict, **kwargs)
    end

    # Resolves and parses the image processor config.
    #
    # These loading options are removed from +kwargs+:
    #   :cache_dir, :force_download, :resume_download, :proxies, :token,
    #   :use_auth_token, :local_files_only, :revision, :subfolder,
    #   :_from_pipeline, :_from_auto
    # Everything else is returned untouched.
    #
    # @return [Array(Hash{Symbol => Object}, Hash)] the config (top-level keys
    #   symbolized) and the remaining kwargs
    # @raise [EnvironmentError] if the config cannot be resolved or is not valid JSON
    def self.get_image_processor_dict(
      pretrained_model_name_or_path, **kwargs
    )
      cache_dir = kwargs.delete(:cache_dir)
      force_download = kwargs.delete(:force_download) { false }
      resume_download = kwargs.delete(:resume_download)
      proxies = kwargs.delete(:proxies)
      token = kwargs.delete(:token)
      # Removed from kwargs and otherwise ignored.
      kwargs.delete(:use_auth_token)
      local_files_only = kwargs.delete(:local_files_only) { false }
      revision = kwargs.delete(:revision)
      subfolder = kwargs.delete(:subfolder) { "" }

      from_pipeline = kwargs.delete(:_from_pipeline)
      from_auto_class = kwargs.delete(:_from_auto) { false }

      user_agent = {file_type: "image processor", from_auto_class: from_auto_class}
      user_agent[:using_pipeline] = from_pipeline unless from_pipeline.nil?

      if Utils::Hub.is_offline_mode && !local_files_only
        Transformers.logger.info("Offline mode: forcing local_files_only: true")
        local_files_only = true
      end

      pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
      is_local = Dir.exist?(pretrained_model_name_or_path)

      # Only a regular file is loaded directly. `File.exist?` is also true for a
      # directory, which then got JSON-parsed itself (EISDIR). Directories go
      # through `cached_file`, which resolves IMAGE_PROCESSOR_NAME inside them.
      if File.file?(pretrained_model_name_or_path)
        resolved_image_processor_file = pretrained_model_name_or_path
        is_local = true
      elsif Utils::Hub.is_remote_url(pretrained_model_name_or_path)
        raise Todo
      else
        image_processor_file = IMAGE_PROCESSOR_NAME
        begin
          # Load from local folder or from cache or download from model Hub and cache
          resolved_image_processor_file = Utils::Hub.cached_file(
            pretrained_model_name_or_path,
            image_processor_file,
            cache_dir: cache_dir,
            force_download: force_download,
            proxies: proxies,
            resume_download: resume_download,
            local_files_only: local_files_only,
            token: token,
            user_agent: user_agent,
            revision: revision,
            subfolder: subfolder
          )
        rescue EnvironmentError
          # Re-raise environment errors from `cached_file` unchanged; they already
          # carry a helpful message adapted to the original exception.
          raise
        rescue
          # For any other exception, we throw a generic error.
          raise EnvironmentError,
            "Can't load image processor for '#{pretrained_model_name_or_path}'. If you were trying to load" +
            " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" +
            " same name. Otherwise, make sure '#{pretrained_model_name_or_path}' is the correct path to a" +
            " directory containing a #{IMAGE_PROCESSOR_NAME} file"
        end
      end

      begin
        image_processor_dict = JSON.load_file(resolved_image_processor_file).transform_keys(&:to_sym)
      rescue JSON::ParserError
        raise EnvironmentError,
          "It looks like the config file at '#{resolved_image_processor_file}' is not a valid JSON file."
      end

      if is_local
        Transformers.logger.info("loading configuration file #{resolved_image_processor_file}")
      else
        Transformers.logger.info(
          "loading configuration file #{image_processor_file} from cache at #{resolved_image_processor_file}"
        )
      end

      # Keys were symbolized above, so these must be Symbol lookups; the
      # previous String lookups never matched.
      unless is_local
        raise Todo if image_processor_dict.key?(:auto_map)
        raise Todo if image_processor_dict.key?(:custom_pipelines)
      end

      [image_processor_dict, kwargs]
    end

    # Builds an instance of the calling class from a config hash.
    #
    # +:size+ / +:crop_size+ kwargs replace the config value before construction,
    # but only when the config already has that key. Any other kwarg whose name
    # matches an instance variable defined after construction overwrites that
    # variable. All remaining kwargs are left unused.
    #
    # @param image_processor_dict [Hash{Symbol => Object}] constructor keywords
    # @param kwargs [Hash] overrides; +return_unused_kwargs: true+ also returns leftovers
    # @return [ImageProcessingMixin, Array(ImageProcessingMixin, Hash)]
    def self.from_dict(image_processor_dict, **kwargs)
      image_processor_dict = image_processor_dict.dup
      return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }

      # The `size` parameter is a dict and was previously an int or tuple in
      # feature extractors. It is set on the dict (rather than applied after
      # construction) so the processor converts it to the appropriate dict and
      # it isn't overwritten if passed in as a kwarg. Same for `crop_size`.
      %i[size crop_size].each do |key|
        if kwargs.key?(key) && image_processor_dict.key?(key)
          image_processor_dict[key] = kwargs.delete(key)
        end
      end

      image_processor = new(**image_processor_dict)

      # Update image_processor with kwargs if needed
      overridden = kwargs.keys.select { |key| image_processor.instance_variable_defined?("@#{key}") }
      overridden.each { |key| image_processor.instance_variable_set("@#{key}", kwargs.delete(key)) }

      Transformers.logger.info("Image processor #{image_processor}")
      return_unused_kwargs ? [image_processor, kwargs] : image_processor
    end
  end
end
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
  # Base class for concrete image processors. Subclasses implement #preprocess;
  # #rescale and #normalize delegate to ImageTransforms.
  class BaseImageProcessor < ImageProcessingMixin
    # NOTE(review): ImageProcessingMixin defines no #initialize in the visible
    # source, so non-empty kwargs here would reach Object#initialize. Confirm
    # that subclasses never forward leftover keywords.
    def initialize(**kwargs)
      super(**kwargs)
    end

    # Calling the processor is shorthand for #preprocess.
    def call(images, **kwargs)
      preprocess(images, **kwargs)
    end

    # @raise [NotImplementedError] always; subclasses must override
    def preprocess(images, **kwargs)
      raise NotImplementedError, "Each image processor must implement its own preprocess method"
    end

    # Multiplies +image+ by +scale+; see ImageTransforms.rescale.
    def rescale(
      image,
      scale,
      data_format: nil,
      input_data_format: nil,
      **kwargs
    )
      ImageTransforms.rescale(image, scale, data_format: data_format, input_data_format: input_data_format, **kwargs)
    end

    # Normalizes +image+ with per-channel +mean+ / +std+; see ImageTransforms.normalize.
    def normalize(
      image,
      mean,
      std,
      data_format: nil,
      input_data_format: nil,
      **kwargs
    )
      ImageTransforms.normalize(
        image, mean, std, data_format: data_format, input_data_format: input_data_format, **kwargs
      )
    end
  end

  module ImageProcessingUtils
    # Normalizes a size specification into a Hash with Symbol keys.
    #
    # * Hash: returned with its keys symbolized. Nested config hashes loaded from
    #   JSON keep String keys, e.g. {"height" => 224}.
    # * Array of two elements: treated as [height, width].
    # * anything else: used for both height and width.
    #
    # @param size [Hash, Array, Object] the size specification
    # @return [Hash{Symbol => Object}]
    # @raise [ArgumentError] for an Array that does not have exactly 2 elements
    def self.get_size_dict(size)
      case size
      when Hash
        size.transform_keys(&:to_sym)
      when Array
        raise ArgumentError, "size must have 2 elements (height, width), got #{size.length}" unless size.length == 2

        {height: size[0], width: size[1]}
      else
        {height: size, width: size}
      end
    end
  end
end
|
@@ -0,0 +1,208 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
  # Low-level image transforms on Numo::NArray images (Vips images for resizing).
  module ImageTransforms
    # Converts +image+ to the requested channel layout.
    #
    # @param image [Numo::NArray] a rank-3 image (the transposes below use 3 axes)
    # @param channel_dim [Object] target layout, passed through ChannelDimension.new
    # @param input_channel_dim [Object, nil] current layout; inferred when nil
    # @return [Numo::NArray] +image+ itself if already in the target layout,
    #   otherwise a transposed copy
    # @raise [ArgumentError] for non-NArray input or an unsupported target layout
    def self.to_channel_dimension_format(
      image,
      channel_dim,
      input_channel_dim: nil
    )
      unless image.is_a?(Numo::NArray)
        raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
      end

      # The inference helper lives in ImageUtils (as used by .resize); the
      # previous unqualified call raised NoMethodError.
      input_channel_dim = ImageUtils.infer_channel_dimension_format(image) if input_channel_dim.nil?

      target_channel_dim = ChannelDimension.new(channel_dim).to_s
      return image if input_channel_dim == target_channel_dim

      if target_channel_dim == ChannelDimension::FIRST
        image.transpose(2, 0, 1)
      elsif target_channel_dim == ChannelDimension::LAST
        image.transpose(1, 2, 0)
      else
        raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}"
      end
    end

    # Multiplies +image+ by +scale+, optionally changes its channel layout, and
    # casts the result to +dtype+.
    #
    # @param data_format [Object, nil] target layout; unchanged when nil
    # @param dtype [Class] Numo type to cast the result to
    # @param input_data_format [Object, nil] current layout; inferred when nil
    # @return [Numo::NArray]
    # @raise [ArgumentError] for non-NArray input
    def self.rescale(
      image,
      scale,
      data_format: nil,
      dtype: Numo::SFloat,
      input_data_format: nil
    )
      unless image.is_a?(Numo::NArray)
        raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
      end

      rescaled_image = image * scale
      unless data_format.nil?
        # `input_channel_dim` is a keyword argument; passing it positionally raised ArgumentError.
        rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_channel_dim: input_data_format)
      end

      rescaled_image.cast_to(dtype)
    end

    # Resizes +image+ to +size+ = [height, width] using Vips.
    #
    # NOTE: +resample+ and +reducing_gap+ are accepted for API compatibility but
    # are not used yet (TODO: default to bilinear resampling once supported).
    #
    # @param size [Array] two elements, [height, width]
    # @param data_format [Object, nil] output layout; defaults to the input layout
    # @param return_numpy [Boolean] when true, convert the result back to a Numo array
    # @param input_data_format [Object, nil] input layout; inferred when nil
    # @return [Numo::NArray, Vips::Image]
    # @raise [ArgumentError] if +size+ does not have 2 elements
    def self.resize(
      image,
      size,
      resample: nil,
      reducing_gap: nil,
      data_format: nil,
      return_numpy: true,
      input_data_format: nil
    )
      raise ArgumentError, "size must have 2 elements" if size.length != 2

      # Keep the input's data format unless told otherwise. The resized Vips
      # image is always channels-last, so find the input format first.
      input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?
      data_format = input_data_format if data_format.nil?

      # Resize through Vips, then convert back to a Numo array.
      do_rescale = false
      unless image.is_a?(Vips::Image)
        do_rescale = _rescale_for_pil_conversion(image)
        image = to_pil_image(image, do_rescale: do_rescale, input_data_format: input_data_format)
      end
      height, width = size
      resized_image = image.thumbnail_image(width, height: height, size: :force)

      if return_numpy
        resized_image = ImageUtils.to_numo_array(resized_image)
        # If the input image channel dimension was of size 1, then it is dropped
        # when converting to an image, so add it back if necessary.
        resized_image = resized_image.expand_dims(-1) if resized_image.ndim == 2
        # The image is always in channels last format after converting from a Vips image
        resized_image = to_channel_dimension_format(
          resized_image, data_format, input_channel_dim: ChannelDimension::LAST
        )
        # If the image was rescaled to [0, 255] before conversion, rescale it back.
        resized_image = rescale(resized_image, 1 / 255.0) if do_rescale
      end
      resized_image
    end

    # Normalizes +image+ per channel: (image - mean) / std.
    #
    # @param mean [Numeric, Enumerable] scalar or one value per channel
    # @param std [Numeric, Enumerable] scalar or one value per channel
    # @param data_format [Object, nil] output layout; unchanged when nil
    # @param input_data_format [Object, nil] input layout; inferred when nil
    # @return [Numo::NArray]
    # @raise [ArgumentError] for non-NArray input or a mean/std of the wrong length
    def self.normalize(
      image,
      mean,
      std,
      data_format: nil,
      input_data_format: nil
    )
      raise ArgumentError, "image must be a numpy array" unless image.is_a?(Numo::NArray)

      # ImageUtils hosts the inference helper; the unqualified call raised NoMethodError.
      input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?

      channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format)
      num_channels = image.shape[channel_axis]

      # Cast non-float images to float32 to avoid errors when subtracting integer
      # (e.g. uint8) values; SFloat/DFloat images keep their dtype.
      image = image.cast_to(Numo::SFloat) unless image.is_a?(Numo::SFloat) || image.is_a?(Numo::DFloat)

      mean = Numo::DFloat.cast(_expand_per_channel(mean, num_channels, "mean"))
      std = Numo::DFloat.cast(_expand_per_channel(std, num_channels, "std"))

      image =
        if input_data_format == ChannelDimension::LAST
          (image - mean) / std
        else
          # Reverse the axes so channels come last, normalize, then restore the layout.
          ((image.transpose - mean) / std).transpose
        end

      return image if data_format.nil?

      # `input_channel_dim` is a keyword argument; passing it positionally raised ArgumentError.
      to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format)
    end

    # Returns +value+ as one entry per channel. Enumerables are checked for
    # length; scalars are repeated.
    # @raise [ArgumentError] if an Enumerable +value+ has the wrong length
    def self._expand_per_channel(value, num_channels, name)
      return [value] * num_channels unless value.is_a?(Enumerable)

      if value.length != num_channels
        raise ArgumentError, "#{name} must have #{num_channels} elements if it is an iterable, got #{value.length}"
      end
      value
    end

    # Converts a Numo image to an 8-bit Vips image (Vips images pass through).
    #
    # @param do_rescale [Boolean, nil] multiply by 255 first; decided
    #   automatically when nil
    # @param input_data_format [Object, nil] input layout; inferred when nil
    # @return [Vips::Image]
    # @raise [ArgumentError] for unsupported input types
    def self.to_pil_image(
      image,
      do_rescale: nil,
      input_data_format: nil
    )
      return image if image.is_a?(Vips::Image)

      # Convert all tensors to numo arrays before converting to Vips image
      raise ArgumentError, "Input image type not supported: #{image.class.name}" unless image.is_a?(Numo::NArray)

      # If the channel has been moved to first dim, we put it back at the end.
      image = to_channel_dimension_format(image, ChannelDimension::LAST, input_channel_dim: input_data_format)

      # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
      # image = image.shape[-1] == 1 ? image.squeeze(-1) : image

      # Rescale the image to be between 0 and 255 if needed.
      do_rescale = _rescale_for_pil_conversion(image) if do_rescale.nil?
      image = rescale(image, 255) if do_rescale

      image = image.cast_to(Numo::UInt8)
      # Channels-last layout: shape is (height, width, bands).
      Vips::Image.new_from_memory(image.to_binary, image.shape[1], image.shape[0], image.shape[2], :uchar)
    end

    # Decides whether an image must be scaled to [0, 255] before conversion.
    # Only UInt8 images are supported so far (returns false); others raise Todo.
    def self._rescale_for_pil_conversion(image)
      raise Todo unless image.is_a?(Numo::UInt8)

      false
    end
  end
end
|