transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,94 @@
|
|
1
|
+
module Transformers
|
2
|
+
module HfHub
|
3
|
+
# Base class for the HTTP-related errors raised in this file.
class HfHubHTTPError < Error
  # The response object handed to the constructor, or nil when none was given.
  # Kept so that code rescuing the error can inspect status and headers.
  attr_reader :response

  # message  - text forwarded to the parent error class.
  # response - optional response object associated with the failure.
  def initialize(message, response = nil)
    super(message)
    @response = response
  end
end
|
8
|
+
|
9
|
+
# Raised by hf_raise_for_status when the X-Error-Code header is "RepoNotFound".
class RepositoryNotFoundError < HfHubHTTPError
end

# Raised by hf_raise_for_status when the X-Error-Code header is "GatedRepo".
class GatedRepoError < RepositoryNotFoundError
end

# Raised by hf_raise_for_status when the X-Error-Message header reports that
# access to the resource is disabled.
class DisabledRepoError < HfHubHTTPError
end

# Raised by hf_raise_for_status when the X-Error-Code header is "RevisionNotFound".
class RevisionNotFoundError < HfHubHTTPError
end

# Raised by hf_raise_for_status when the X-Error-Code header is "EntryNotFound".
class EntryNotFoundError < HfHubHTTPError
end

# Specialization of EntryNotFoundError; it is not raised anywhere in this file.
class LocalEntryNotFoundError < EntryNotFoundError
end

# Raised by hf_raise_for_status for responses whose status code is 400.
class BadRequestError < HfHubHTTPError
end
|
22
|
+
|
23
|
+
class << self
|
24
|
+
# Checks an HTTP response and converts a failure into one of the Hub error classes.
#
# response      - response object; must answer `value` (raises on failure), `code`,
#                 `uri` and `[]` for header lookup. Instances of
#                 Net::HTTPRedirection are never checked.
# endpoint_name - optional label that only appears in the 400 "Bad request" message.
#
# Returns whatever `response.value` returns when it does not raise (nil for redirects).
# Raises RevisionNotFoundError, EntryNotFoundError, GatedRepoError, DisabledRepoError,
# RepositoryNotFoundError or BadRequestError for the matching header/status, and a
# plain HfHubHTTPError for 403 and for every other failure.
def hf_raise_for_status(response, endpoint_name: nil)
  response.value unless response.is_a?(Net::HTTPRedirection)
rescue => e
  # The original exception has to be bound: its message is reused by the generic
  # fallback at the bottom.
  error_code = response["X-Error-Code"]
  error_message = response["X-Error-Message"]

  if error_code == "RevisionNotFound"
    message = "#{response.code} Client Error.\n\nRevision Not Found for url: #{response.uri}."
    raise RevisionNotFoundError.new(message, response)

  elsif error_code == "EntryNotFound"
    message = "#{response.code} Client Error.\n\nEntry Not Found for url: #{response.uri}."
    raise EntryNotFoundError.new(message, response)

  elsif error_code == "GatedRepo"
    message = "#{response.code} Client Error.\n\nCannot access gated repo for url #{response.uri}."
    raise GatedRepoError.new(message, response)

  elsif error_message == "Access to this resource is disabled."
    message =
      "#{response.code} Client Error." \
      "\n\n" \
      "Cannot access repository for url #{response.uri}." \
      "\n" \
      "Access to this resource is disabled."
    raise DisabledRepoError.new(message, response)

  elsif error_code == "RepoNotFound"
    # 401 is misleading as it is returned for:
    #    - private and gated repos if user is not authenticated
    #    - missing repos
    # => for now, we process them as `RepoNotFound` anyway.
    # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
    message =
      "#{response.code} Client Error." \
      "\n\n" \
      "Repository Not Found for url: #{response.uri}." \
      "\nPlease make sure you specified the correct `repo_id` and" \
      " `repo_type`.\nIf you are trying to access a private or gated repo," \
      " make sure you are authenticated."
    raise RepositoryNotFoundError.new(message, response)

  elsif response.code.to_i == 400
    message =
      if endpoint_name.nil?
        "\n\nBad request:"
      else
        "\n\nBad request for #{endpoint_name} endpoint:"
      end
    raise BadRequestError.new(message, response)

  elsif response.code.to_i == 403
    # The status code is interpolated here; the literal text "{response.code}"
    # used to be emitted because the leading `#` was missing.
    message =
      "\n\n#{response.code} Forbidden: #{error_message}." \
      "\nCannot access content at: #{response.uri}." \
      "\nIf you are trying to create or update content, " \
      "make sure you have a token with the `write` role."
    raise HfHubHTTPError.new(message, response)
  end

  # Convert the original error into a `HfHubHTTPError` so the response travels with it.
  raise HfHubHTTPError.new(e.to_s, response)
end
|
92
|
+
end
|
93
|
+
end
|
94
|
+
end
|
@@ -0,0 +1,109 @@
|
|
1
|
+
# Copyright 2022-present, the HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
module HfHub
|
17
|
+
class << self
|
18
|
+
# Assembles the HTTP headers for a Hub request.
#
# token           - forwarded to get_token_to_send to decide which token is used.
# is_write_action - forwarded to _validate_token_to_send together with that token.
# library_name, library_version, user_agent - forwarded to _http_user_agent.
# headers         - optional Hash merged last, so its entries override the built ones.
#
# Returns a Hash with a "user-agent" entry, an "authorization" entry when a token
# is available, plus everything from `headers`.
def build_hf_headers(
  token: nil,
  is_write_action: false,
  library_name: nil,
  library_version: nil,
  user_agent: nil,
  headers: nil
)
  # Resolve and validate the token before building anything.
  resolved_token = get_token_to_send(token)
  _validate_token_to_send(resolved_token, is_write_action)

  agent = _http_user_agent(
    library_name: library_name,
    library_version: library_version,
    user_agent: user_agent
  )

  combined = {"user-agent" => agent}
  combined["authorization"] = "Bearer #{resolved_token}" unless resolved_token.nil?
  combined.merge!(headers) if headers
  combined
end
|
46
|
+
|
47
|
+
# Decides which token, if any, accompanies a request.
#
# token - a String (used verbatim), false (never send a token), true (a token is
#         mandatory) or anything else (fall back to the cached token).
#
# Returns the token String or nil.
# Raises LocalTokenNotFoundError when `token` is true and no cached token exists.
def get_token_to_send(token)
  # An explicit token always wins.
  return token if token.is_a?(String)

  # The caller explicitly forbade sending a token.
  return nil if token == false

  # Token lookup in the local cache is not wired up here, so this is always nil.
  cached_token = nil # get_token

  # The caller explicitly requires a token.
  if token == true
    return cached_token unless cached_token.nil?

    raise LocalTokenNotFoundError,
      "Token is required (`token: true`), but no token found. You" +
      " need to provide a token or be logged in to Hugging Face with" +
      " `huggingface-cli login` or `huggingface_hub.login`. See" +
      " https://huggingface.co/settings/tokens."
  end

  # Implicit use of the cached token can be switched off globally.
  return nil if HF_HUB_DISABLE_IMPLICIT_TOKEN

  # Otherwise the cached token is used, since nothing forbade it.
  cached_token
end
|
81
|
+
|
82
|
+
# Ensures a token is present for write actions.
#
# token           - the token about to be sent, or nil.
# is_write_action - truthy when the request needs write access.
#
# Returns nil.
# Raises ArgumentError when a write action has no token.
def _validate_token_to_send(token, is_write_action)
  return unless is_write_action && token.nil?

  raise ArgumentError,
    "Token is required (write-access action) but no token found. You need" +
    " to provide a token or be logged in to Hugging Face with" +
    " `huggingface-cli login` or `huggingface_hub.login`. See" +
    " https://huggingface.co/settings/tokens."
end
|
93
|
+
|
94
|
+
# Builds the value of the "user-agent" header.
#
# library_name    - name of the calling library; "unknown/None" is used when nil.
# library_version - version paired with `library_name`.
# user_agent      - optional extra information appended to the result: a Hash is
#                   rendered as "key/value" pairs, a String is appended verbatim.
#                   Any other value is ignored.
#
# Returns a String of "; "-separated segments.
def _http_user_agent(
  library_name: nil,
  library_version: nil,
  user_agent: nil
)
  ua = library_name.nil? ? "unknown/None" : "#{library_name}/#{library_version}"

  # Take "major.minor" textually: converting the version string to a Float would
  # turn a two-digit minor such as "3.10" into 3.1.
  ua += "; ruby/#{RUBY_VERSION[/\A\d+\.\d+/]}"

  case user_agent
  when Hash
    ua += "; " + user_agent.map { |key, value| "#{key}/#{value}" }.join("; ") unless user_agent.empty?
  when String
    ua += "; #{user_agent}"
  end

  ua
end
|
107
|
+
end
|
108
|
+
end
|
109
|
+
end
|
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright 2020 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
class ImageProcessingMixin
|
17
|
+
# Builds an image processor from a model name or path.
#
# The named options are folded back into `kwargs` (the token only when it is not
# nil), get_image_processor_dict loads the configuration, and from_dict
# instantiates the processor with whatever options remain.
#
# Returns whatever from_dict returns.
def self.from_pretrained(
  pretrained_model_name_or_path,
  cache_dir: nil,
  force_download: false,
  local_files_only: false,
  token: nil,
  revision: "main",
  **kwargs
)
  kwargs.merge!(
    cache_dir: cache_dir,
    force_download: force_download,
    local_files_only: local_files_only,
    revision: revision
  )
  kwargs[:token] = token unless token.nil?

  config, remaining = get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
  from_dict(config, **remaining)
end
|
39
|
+
|
40
|
+
# Locates and parses the image processor configuration file.
#
# pretrained_model_name_or_path - path to a configuration file, a directory, or
#                                 an identifier handed to Utils::Hub.cached_file.
# kwargs - loading options. Those consumed here (:cache_dir, :force_download,
#          :resume_download, :proxies, :token, :use_auth_token, :local_files_only,
#          :revision, :subfolder, :_from_pipeline, :_from_auto) are removed.
#
# Returns [configuration Hash with symbolized top-level keys, remaining kwargs].
# Raises EnvironmentError when the file cannot be resolved or is not valid JSON.
def self.get_image_processor_dict(
  pretrained_model_name_or_path, **kwargs
)
  cache_dir = kwargs.delete(:cache_dir)
  force_download = kwargs.delete(:force_download) { false }
  resume_download = kwargs.delete(:resume_download)
  proxies = kwargs.delete(:proxies)
  token = kwargs.delete(:token)
  # Accepted for compatibility; the value is not used, only removed from kwargs.
  kwargs.delete(:use_auth_token)
  local_files_only = kwargs.delete(:local_files_only) { false }
  revision = kwargs.delete(:revision)
  subfolder = kwargs.delete(:subfolder) { "" }

  from_pipeline = kwargs.delete(:_from_pipeline)
  from_auto_class = kwargs.delete(:_from_auto) { false }

  user_agent = {file_type: "image processor", from_auto_class: from_auto_class}
  user_agent[:using_pipeline] = from_pipeline unless from_pipeline.nil?

  if Utils::Hub.is_offline_mode && !local_files_only
    Transformers.logger.info("Offline mode: forcing local_files_only: true")
    local_files_only = true
  end

  pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
  is_local = Dir.exist?(pretrained_model_name_or_path)
  if is_local
    image_processor_file = File.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
  end

  # Only a regular file is read directly. `File.exist?` is also true for
  # directories, which would end up being passed to the JSON loader; directories
  # must go through `cached_file` below instead.
  if File.file?(pretrained_model_name_or_path)
    resolved_image_processor_file = pretrained_model_name_or_path
    is_local = true
  elsif Utils::Hub.is_remote_url(pretrained_model_name_or_path)
    raise Todo
  else
    image_processor_file = IMAGE_PROCESSOR_NAME
    begin
      # Load from local folder or from cache or download from model Hub and cache
      resolved_image_processor_file = Utils::Hub.cached_file(
        pretrained_model_name_or_path,
        image_processor_file,
        cache_dir: cache_dir,
        force_download: force_download,
        proxies: proxies,
        resume_download: resume_download,
        local_files_only: local_files_only,
        token: token,
        user_agent: user_agent,
        revision: revision,
        subfolder: subfolder
      )
    rescue EnvironmentError
      # Environment errors raised by `cached_file` are passed through unchanged.
      raise
    rescue
      # For any other exception, we throw a generic error.
      raise EnvironmentError,
        "Can't load image processor for '#{pretrained_model_name_or_path}'. If you were trying to load" +
        " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" +
        " same name. Otherwise, make sure '#{pretrained_model_name_or_path}' is the correct path to a" +
        " directory containing a #{IMAGE_PROCESSOR_NAME} file"
    end
  end

  begin
    image_processor_dict = JSON.load_file(resolved_image_processor_file).transform_keys(&:to_sym)
  rescue JSON::ParserError
    raise EnvironmentError,
      "It looks like the config file at '#{resolved_image_processor_file}' is not a valid JSON file."
  end

  if is_local
    Transformers.logger.info("loading configuration file #{resolved_image_processor_file}")
  else
    Transformers.logger.info(
      "loading configuration file #{image_processor_file} from cache at #{resolved_image_processor_file}"
    )
  end

  if !is_local
    # NOTE(review): the top-level keys were symbolized above, so these String
    # lookups never match. Left as-is because enabling them would make such
    # configurations raise Todo; confirm the intended behavior.
    if image_processor_dict.include?("auto_map")
      raise Todo
    end
    if image_processor_dict.include?("custom_pipelines")
      raise Todo
    end
  end
  [image_processor_dict, kwargs]
end
|
132
|
+
|
133
|
+
# Instantiates an image processor from a configuration Hash.
#
# image_processor_dict - constructor arguments; the Hash is duplicated before use.
# kwargs - overrides. :size and :crop_size replace the matching entries before
#          construction when both sides have them. Any other key that matches an
#          instance variable already defined on the new object is assigned after
#          construction. :return_unused_kwargs selects the return shape.
#
# Returns the image processor, or [image processor, unused kwargs] when
# :return_unused_kwargs is truthy.
def self.from_dict(image_processor_dict, **kwargs)
  processor_args = image_processor_dict.dup
  return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }

  # `size` and `crop_size` are written into the constructor arguments directly so
  # the image processor converts them itself and they are not overwritten later.
  %i[size crop_size].each do |name|
    if kwargs.include?(name) && processor_args.include?(name)
      processor_args[name] = kwargs.delete(name)
    end
  end

  image_processor = new(**processor_args)

  # Apply the remaining overrides that correspond to existing instance variables
  # and drop them from the unused set.
  applied_keys = kwargs.keys.select do |key|
    image_processor.instance_variable_defined?("@#{key}")
  end
  applied_keys.each do |key|
    image_processor.instance_variable_set("@#{key}", kwargs.delete(key))
  end

  Transformers.logger.info("Image processor #{image_processor}")
  return image_processor unless return_unused_kwargs

  [image_processor, kwargs]
end
|
168
|
+
end
|
169
|
+
end
|
@@ -0,0 +1,63 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
# Common base for image processors: calling an instance runs `preprocess`, and
# the rescale/normalize helpers delegate to ImageTransforms.
class BaseImageProcessor < ImageProcessingMixin
  # All keyword arguments are handed to the parent constructor.
  def initialize(**kwargs)
    super(**kwargs)
  end

  # Shorthand for `preprocess`.
  def call(images, **kwargs)
    preprocess(images, **kwargs)
  end

  # Must be overridden; this base implementation always raises NotImplementedError.
  def preprocess(images, **kwargs)
    raise NotImplementedError, "Each image processor must implement its own preprocess method"
  end

  # Delegates to ImageTransforms.rescale; extra keywords are passed through.
  def rescale(image, scale, data_format: nil, input_data_format: nil, **kwargs)
    ImageTransforms.rescale(
      image,
      scale,
      data_format: data_format,
      input_data_format: input_data_format,
      **kwargs
    )
  end

  # Delegates to ImageTransforms.normalize; extra keywords are passed through.
  def normalize(image, mean, std, data_format: nil, input_data_format: nil, **kwargs)
    ImageTransforms.normalize(
      image,
      mean,
      std,
      data_format: data_format,
      input_data_format: input_data_format,
      **kwargs
    )
  end
end
|
52
|
+
|
53
|
+
module ImageProcessingUtils
  # Normalizes a size specification.
  #
  # A Hash is returned as-is (the same object); any other value is used for both
  # :height and :width.
  def self.get_size_dict(size)
    return size if size.is_a?(Hash)

    {height: size, width: size}
  end
end
|
63
|
+
end
|
@@ -0,0 +1,208 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Inc. team.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
module Transformers
|
16
|
+
module ImageTransforms
|
17
|
+
# Moves the channel axis of an image to the requested position.
#
# image             - a Numo::NArray; the transposes below use three axes.
# channel_dim       - target format, compared against ChannelDimension::FIRST/LAST
#                     after conversion through ChannelDimension.
# input_channel_dim - current format; inferred through ImageUtils when nil.
#
# Returns `image` itself when it is already in the target format, otherwise a
# transposed array.
# Raises ArgumentError for a non-NArray input or an unsupported target format.
def self.to_channel_dimension_format(
  image,
  channel_dim,
  input_channel_dim: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
  end

  # The inference helper is provided by ImageUtils (as used by `resize`); this
  # module defines no method of that name.
  input_channel_dim = ImageUtils.infer_channel_dimension_format(image) if input_channel_dim.nil?

  target_channel_dim = ChannelDimension.new(channel_dim).to_s
  return image if input_channel_dim == target_channel_dim

  if target_channel_dim == ChannelDimension::FIRST
    image.transpose(2, 0, 1)
  elsif target_channel_dim == ChannelDimension::LAST
    image.transpose(1, 2, 0)
  else
    raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}"
  end
end
|
45
|
+
|
46
|
+
# Multiplies an image by `scale` and casts the result.
#
# image             - a Numo::NArray.
# scale             - factor applied with `*`.
# data_format       - when given, the result is converted to this channel format.
# dtype             - Numo class the result is cast to (Numo::SFloat by default).
# input_data_format - channel format of `image`, forwarded to the conversion.
#
# Returns the scaled array cast to `dtype`.
# Raises ArgumentError when `image` is not a Numo::NArray.
def self.rescale(
  image,
  scale,
  data_format: nil,
  dtype: Numo::SFloat,
  input_data_format: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
  end

  rescaled_image = image * scale
  unless data_format.nil?
    # The input format is a keyword argument of to_channel_dimension_format;
    # passing it positionally raises ArgumentError.
    rescaled_image = to_channel_dimension_format(
      rescaled_image, data_format, input_channel_dim: input_data_format
    )
  end

  rescaled_image.cast_to(dtype)
end
|
66
|
+
|
67
|
+
# Resizes an image to `size`, given as [height, width].
#
# image             - a Vips::Image, or an array that is first converted with
#                     to_pil_image.
# size              - two-element collection, destructured as height then width.
# resample          - accepted but currently unused (resampling is a TODO).
# reducing_gap      - accepted but unused.
# data_format       - channel format of the returned array; defaults to the
#                     input's format.
# return_numpy      - when truthy the result is converted with
#                     ImageUtils.to_numo_array; otherwise the resized Vips image
#                     is returned.
# input_data_format - channel format of `image`; inferred when nil.
#
# Raises ArgumentError when `size` does not have exactly two elements.
def self.resize(
  image,
  size,
  resample: nil,
  reducing_gap: nil,
  data_format: nil,
  return_numpy: true,
  input_data_format: nil
)
  raise ArgumentError, "size must have 2 elements" unless size.length == 2

  # Keep the input's data format unless another one is requested. The resized
  # image is handled as channels-last below, so determine the input format first.
  input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?
  data_format = input_data_format if data_format.nil?

  # Arrays are converted to a Vips image before resizing.
  do_rescale = false
  unless image.is_a?(Vips::Image)
    do_rescale = _rescale_for_pil_conversion(image)
    image = to_pil_image(image, do_rescale: do_rescale, input_data_format: input_data_format)
  end

  height, width = size
  # TODO support resample
  resized_image = image.thumbnail_image(width, height: height, size: :force)
  return resized_image unless return_numpy

  resized_image = ImageUtils.to_numo_array(resized_image)
  # A two-dimensional result gets its channel axis added back at the end.
  resized_image = resized_image.expand_dims(-1) if resized_image.ndim == 2
  # The converted array is treated as channels-last.
  resized_image = to_channel_dimension_format(
    resized_image, data_format, input_channel_dim: ChannelDimension::LAST
  )
  # Undo the scaling applied before the conversion, if any.
  do_rescale ? rescale(resized_image, 1 / 255.0) : resized_image
end
|
115
|
+
|
116
|
+
# Normalizes an image as (image - mean) / std.
#
# image             - a Numo::NArray; anything other than SFloat/DFloat is cast to
#                     SFloat first.
# mean, std         - a scalar (repeated for every channel) or an Enumerable with
#                     one entry per channel.
# data_format       - when given, the result is converted to this channel format.
# input_data_format - channel format of `image`; inferred through ImageUtils
#                     when nil.
#
# Returns the normalized array.
# Raises ArgumentError when `image` is not a Numo::NArray or when `mean`/`std`
# have the wrong number of elements.
def self.normalize(
  image,
  mean,
  std,
  data_format: nil,
  input_data_format: nil
)
  raise ArgumentError, "image must be a numpy array" unless image.is_a?(Numo::NArray)

  # The inference helper is provided by ImageUtils (as used by `resize`); this
  # module defines no method of that name.
  input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?

  channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format)
  num_channels = image.shape[channel_axis]

  # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
  if !image.is_a?(Numo::SFloat) && !image.is_a?(Numo::DFloat)
    image = image.cast_to(Numo::SFloat)
  end

  if mean.is_a?(Enumerable)
    if mean.length != num_channels
      raise ArgumentError, "mean must have #{num_channels} elements if it is an iterable, got #{mean.length}"
    end
  else
    mean = [mean] * num_channels
  end
  mean = Numo::DFloat.cast(mean)

  if std.is_a?(Enumerable)
    if std.length != num_channels
      raise ArgumentError, "std must have #{num_channels} elements if it is an iterable, got #{std.length}"
    end
  else
    std = [std] * num_channels
  end
  std = Numo::DFloat.cast(std)

  if input_data_format == ChannelDimension::LAST
    image = (image - mean) / std
  else
    # Transposed so the channel axis comes last for the arithmetic, then restored.
    image = ((image.transpose - mean) / std).transpose
  end

  unless data_format.nil?
    # The input format is a keyword argument of to_channel_dimension_format;
    # passing it positionally raises ArgumentError.
    image = to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format)
  end
  image
end
|
167
|
+
|
168
|
+
# Converts an array image into a Vips::Image.
#
# image             - a Vips::Image (returned untouched) or a Numo::NArray.
# do_rescale        - whether to multiply by 255 first; decided by
#                     _rescale_for_pil_conversion when nil.
# input_data_format - channel format of `image`, forwarded to the conversion.
#
# Raises ArgumentError for any other input type.
def self.to_pil_image(
  image,
  do_rescale: nil,
  input_data_format: nil
)
  return image if image.is_a?(Vips::Image)

  # Only numo arrays can be converted to a Vips image here.
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image type not supported: #{image.class.name}"
  end

  # If the channel has been moved to first dim, we put it back at the end.
  image = to_channel_dimension_format(image, ChannelDimension::LAST, input_channel_dim: input_data_format)

  # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
  # image = image.shape[-1] == 1 ? image.squeeze(-1) : image

  # Rescale the image to be between 0 and 255 if needed.
  do_rescale = _rescale_for_pil_conversion(image) if do_rescale.nil?
  image = rescale(image, 255) if do_rescale

  pixels = image.cast_to(Numo::UInt8)
  # Axis 1 is passed as the width, axis 0 as the height, axis 2 as the band count.
  rows = pixels.shape[0]
  columns = pixels.shape[1]
  bands = pixels.shape[2]
  Vips::Image.new_from_memory(pixels.to_binary, columns, rows, bands, :uchar)
end
|
198
|
+
|
199
|
+
# Tells whether an image must be rescaled before conversion to a Vips image.
#
# Only Numo::UInt8 input is supported so far, and it needs no rescaling.
#
# Returns false for Numo::UInt8 images.
# Raises Todo for every other input.
def self._rescale_for_pil_conversion(image)
  raise Todo unless image.is_a?(Numo::UInt8)

  false
end
|
207
|
+
end
|
208
|
+
end
|