transformers-rb 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (65) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
@@ -0,0 +1,94 @@
1
+ module Transformers
2
+ module HfHub
3
# Base error for failed HTTP calls to the Hugging Face Hub.
#
# The response handed to the constructor is kept on the exception so that
# code rescuing the error can inspect it (status code, headers, ...).
# Previously the argument was accepted but silently discarded.
class HfHubHTTPError < Error
  # @return [Object, nil] the response given to the constructor, if any
  attr_reader :response

  # @param message [String] description of the failure
  # @param response [Object, nil] HTTP response associated with the failure
  def initialize(message, response = nil)
    super(message)
    @response = response
  end
end
8
+
9
# Raised by hf_raise_for_status when the X-Error-Code header is "RepoNotFound".
class RepositoryNotFoundError < HfHubHTTPError
end

# Raised by hf_raise_for_status when the X-Error-Code header is "GatedRepo".
# Being a RepositoryNotFoundError, it is also caught by rescuers of that class.
class GatedRepoError < RepositoryNotFoundError
end

# Raised by hf_raise_for_status when the X-Error-Message header says that
# access to the resource is disabled.
class DisabledRepoError < HfHubHTTPError
end

# Raised by hf_raise_for_status when the X-Error-Code header is "RevisionNotFound".
class RevisionNotFoundError < HfHubHTTPError
end

# Raised by hf_raise_for_status when the X-Error-Code header is "EntryNotFound".
class EntryNotFoundError < HfHubHTTPError
end

# Specialisation of EntryNotFoundError (not raised by hf_raise_for_status).
class LocalEntryNotFoundError < EntryNotFoundError
end

# Raised by hf_raise_for_status for responses with status code 400.
class BadRequestError < HfHubHTTPError
end
22
+
23
+ class << self
24
# Raises a Hub-specific error when +response+ reports a failed request.
#
# `response.value` is used to detect failure (redirections are skipped).
# When it raises, the `X-Error-Code` / `X-Error-Message` headers and the
# status code are used to pick the most specific error class; anything
# else is wrapped in a plain HfHubHTTPError carrying the original message.
#
# @param response [Net::HTTPResponse] response to check
# @param endpoint_name [String, nil] endpoint label used in the 400 message
# @return [nil] when the response does not signal an error
# @raise [RevisionNotFoundError, EntryNotFoundError, GatedRepoError,
#   DisabledRepoError, RepositoryNotFoundError, BadRequestError, HfHubHTTPError]
def hf_raise_for_status(response, endpoint_name: nil)
  response.value unless response.is_a?(Net::HTTPRedirection)
rescue => e
  # The exception must be bound to `e`: it is needed for the generic
  # fallback at the bottom (it used to be unbound, causing a NameError).
  error_code = response["X-Error-Code"]
  error_message = response["X-Error-Message"]
  client_error = "#{response.code} Client Error."

  if error_code == "RevisionNotFound"
    message = "#{client_error}\n\nRevision Not Found for url: #{response.uri}."
    raise RevisionNotFoundError.new(message, response)

  elsif error_code == "EntryNotFound"
    message = "#{client_error}\n\nEntry Not Found for url: #{response.uri}."
    raise EntryNotFoundError.new(message, response)

  elsif error_code == "GatedRepo"
    message = "#{client_error}\n\nCannot access gated repo for url #{response.uri}."
    raise GatedRepoError.new(message, response)

  elsif error_message == "Access to this resource is disabled."
    message =
      "#{client_error}\n\n" \
      "Cannot access repository for url #{response.uri}.\n" \
      "Access to this resource is disabled."
    raise DisabledRepoError.new(message, response)

  elsif error_code == "RepoNotFound"
    # 401 is misleading as it is returned for:
    #    - private and gated repos if user is not authenticated
    #    - missing repos
    # => for now, we process them as `RepoNotFound` anyway.
    # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
    message =
      "#{client_error}\n\n" \
      "Repository Not Found for url: #{response.uri}." \
      "\nPlease make sure you specified the correct `repo_id` and" \
      " `repo_type`.\nIf you are trying to access a private or gated repo," \
      " make sure you are authenticated."
    raise RepositoryNotFoundError.new(message, response)

  elsif response.code.to_i == 400
    message = endpoint_name.nil? ? "\n\nBad request:" : "\n\nBad request for #{endpoint_name} endpoint:"
    raise BadRequestError.new(message, response)

  elsif response.code.to_i == 403
    # `#{response.code}` was previously written without the `#`, so the
    # literal text "{response.code}" ended up in the message.
    message =
      "\n\n#{response.code} Forbidden: #{error_message}." \
      "\nCannot access content at: #{response.uri}." \
      "\nIf you are trying to create or update content, " \
      "make sure you have a token with the `write` role."
    raise HfHubHTTPError.new(message, response)
  end

  # Convert the HTTP error into a `HfHubHTTPError` so that callers only
  # have to rescue Hub error classes.
  raise HfHubHTTPError.new(e.to_s, response)
end
92
+ end
93
+ end
94
+ end
@@ -0,0 +1,109 @@
1
+ # Copyright 2022-present, the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
+ module HfHub
17
+ class << self
18
# Builds the HTTP headers sent with Hub requests.
#
# The token is resolved with get_token_to_send and checked with
# _validate_token_to_send. The result always has a "user-agent" entry, has an
# "authorization" entry when a token was resolved, and is finally merged
# with +headers+ (entries from +headers+ win on conflict).
#
# @param token [String, Boolean, nil] forwarded to get_token_to_send
# @param is_write_action [Boolean] forwarded to _validate_token_to_send
# @param library_name [String, nil] forwarded to _http_user_agent
# @param library_version [String, nil] forwarded to _http_user_agent
# @param user_agent [Object, nil] forwarded to _http_user_agent
# @param headers [Hash, nil] extra headers merged over the generated ones
# @return [Hash] the combined headers
def build_hf_headers(
  token: nil,
  is_write_action: false,
  library_name: nil,
  library_version: nil,
  user_agent: nil,
  headers: nil
)
  resolved_token = get_token_to_send(token)
  _validate_token_to_send(resolved_token, is_write_action)

  agent = _http_user_agent(
    library_name: library_name,
    library_version: library_version,
    user_agent: user_agent
  )

  combined = {"user-agent" => agent}
  combined["authorization"] = "Bearer #{resolved_token}" unless resolved_token.nil?
  combined.merge!(headers) if headers
  combined
end
46
+
47
# Decides which token, if any, should be sent with a request.
#
# @param token [String, Boolean, nil]
#   - a String is returned as is
#   - +false+ forbids sending a token and yields nil
#   - +true+ requires a cached token and raises when there is none
#   - anything else falls back to the cached token, unless
#     HF_HUB_DISABLE_IMPLICIT_TOKEN is set
# @return [String, nil] the token to send
# @raise [LocalTokenNotFoundError] when +token+ is +true+ and no token is cached
def get_token_to_send(token)
  # Explicit token
  return token if token.is_a?(String)

  # Token explicitly forbidden
  return nil if token == false

  # Reading the locally cached token is not wired up yet, so it is always nil
  cached_token = nil # get_token

  # Token explicitly required
  if token == true
    return cached_token unless cached_token.nil?

    raise LocalTokenNotFoundError,
      "Token is required (`token: true`), but no token found. You" \
      " need to provide a token or be logged in to Hugging Face with" \
      " `huggingface-cli login` or `huggingface_hub.login`. See" \
      " https://huggingface.co/settings/tokens."
  end

  # Implicit use of the cached token may be disabled by configuration;
  # otherwise it is used since the caller did not forbid it.
  HF_HUB_DISABLE_IMPLICIT_TOKEN ? nil : cached_token
end
81
+
82
# Ensures a token is available when the request performs a write action.
#
# @param token [String, nil] token that is about to be sent
# @param is_write_action [Boolean] whether the request needs write access
# @return [nil]
# @raise [ArgumentError] when a write action is requested without a token
def _validate_token_to_send(token, is_write_action)
  return unless is_write_action && token.nil?

  raise ArgumentError,
    "Token is required (write-access action) but no token found. You need" \
    " to provide a token or be logged in to Hugging Face with" \
    " `huggingface-cli login` or `huggingface_hub.login`. See" \
    " https://huggingface.co/settings/tokens."
end
93
+
94
# Formats the user-agent string sent to the Hub.
#
# The result is "<library_name>/<library_version>" (or "unknown/None" when
# no library name is given) followed by "; ruby/<version>", where the
# version is RUBY_VERSION converted with `to_f`.
#
# @param library_name [String, nil] name of the calling library
# @param library_version [String, nil] version of the calling library
# @param user_agent [Object, nil] accepted but currently not used
# @return [String] the user-agent value
def _http_user_agent(library_name: nil, library_version: nil, user_agent: nil)
  library = library_name.nil? ? "unknown/None" : "#{library_name}/#{library_version}"
  "#{library}; ruby/#{RUBY_VERSION.to_f}"
end
107
+ end
108
+ end
109
+ end
@@ -0,0 +1,169 @@
1
+ # Copyright 2020 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
+ class ImageProcessingMixin
17
# Instantiates an image processor from a model name or path.
#
# The named options are folded back into +kwargs+ (the token only when it
# is not nil), the configuration is loaded with get_image_processor_dict,
# and the instance is built with from_dict using whatever options
# get_image_processor_dict handed back.
#
# @param pretrained_model_name_or_path [Object] model identifier or path
# @param cache_dir [Object, nil] forwarded to get_image_processor_dict
# @param force_download [Boolean] forwarded to get_image_processor_dict
# @param local_files_only [Boolean] forwarded to get_image_processor_dict
# @param token [Object, nil] forwarded to get_image_processor_dict when not nil
# @param revision [String] forwarded to get_image_processor_dict
# @param kwargs [Hash] extra options for get_image_processor_dict / from_dict
# @return [Object] whatever from_dict returns
def self.from_pretrained(
  pretrained_model_name_or_path,
  cache_dir: nil,
  force_download: false,
  local_files_only: false,
  token: nil,
  revision: "main",
  **kwargs
)
  kwargs.merge!(
    cache_dir: cache_dir,
    force_download: force_download,
    local_files_only: local_files_only,
    revision: revision
  )
  kwargs[:token] = token unless token.nil?

  processor_config, remaining = get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
  from_dict(processor_config, **remaining)
end
39
+
40
# Resolves and parses the image processor configuration for a model.
#
# +pretrained_model_name_or_path+ may be a path to a JSON file (read
# directly) or anything else that Utils::Hub.cached_file can resolve to a
# file named IMAGE_PROCESSOR_NAME (a local directory, a cached model, a
# model on the Hub). Remote URLs are not supported yet.
#
# The loading options listed at the top of the method are removed from
# +kwargs+; whatever remains is returned alongside the parsed configuration.
#
# @param pretrained_model_name_or_path [Object] converted with `to_s`
# @param kwargs [Hash] loading options plus pass-through options
# @return [Array(Hash, Hash)] symbol-keyed configuration and the unused kwargs
# @raise [EnvironmentError] when the file cannot be resolved or is not valid JSON
def self.get_image_processor_dict(
  pretrained_model_name_or_path, **kwargs
)
  cache_dir = kwargs.delete(:cache_dir)
  force_download = kwargs.delete(:force_download) { false }
  resume_download = kwargs.delete(:resume_download)
  proxies = kwargs.delete(:proxies)
  token = kwargs.delete(:token)
  # Accepted for compatibility; removed from kwargs but otherwise unused.
  kwargs.delete(:use_auth_token)
  local_files_only = kwargs.delete(:local_files_only) { false }
  revision = kwargs.delete(:revision)
  subfolder = kwargs.delete(:subfolder) { "" }

  from_pipeline = kwargs.delete(:_from_pipeline)
  from_auto_class = kwargs.delete(:_from_auto) { false }

  user_agent = {file_type: "image processor", from_auto_class: from_auto_class}
  user_agent[:using_pipeline] = from_pipeline unless from_pipeline.nil?

  if Utils::Hub.is_offline_mode && !local_files_only
    Transformers.logger.info("Offline mode: forcing local_files_only: true")
    local_files_only = true
  end

  pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
  is_local = Dir.exist?(pretrained_model_name_or_path)

  # `File.file?` rather than `File.exist?`: the latter is also true for
  # directories, which sent local model directories straight to the JSON
  # parser instead of letting `cached_file` locate the file inside them.
  if File.file?(pretrained_model_name_or_path)
    resolved_image_processor_file = pretrained_model_name_or_path
    is_local = true
  elsif Utils::Hub.is_remote_url(pretrained_model_name_or_path)
    raise Todo
  else
    image_processor_file = IMAGE_PROCESSOR_NAME
    begin
      # Load from local folder or from cache or download from model Hub and cache
      resolved_image_processor_file = Utils::Hub.cached_file(
        pretrained_model_name_or_path,
        image_processor_file,
        cache_dir: cache_dir,
        force_download: force_download,
        proxies: proxies,
        resume_download: resume_download,
        local_files_only: local_files_only,
        token: token,
        user_agent: user_agent,
        revision: revision,
        subfolder: subfolder
      )
    rescue EnvironmentError
      # Environment errors raised by `cached_file` are passed through
      # unchanged so their message reaches the caller.
      raise
    rescue
      # For any other exception, we throw a generic error.
      raise EnvironmentError,
        "Can't load image processor for '#{pretrained_model_name_or_path}'. If you were trying to load" +
        " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" +
        " same name. Otherwise, make sure '#{pretrained_model_name_or_path}' is the correct path to a" +
        " directory containing a #{IMAGE_PROCESSOR_NAME} file"
    end
  end

  begin
    image_processor_dict = JSON.load_file(resolved_image_processor_file).transform_keys(&:to_sym)
  rescue JSON::ParserError
    raise EnvironmentError,
      "It looks like the config file at '#{resolved_image_processor_file}' is not a valid JSON file."
  end

  if is_local
    Transformers.logger.info("loading configuration file #{resolved_image_processor_file}")
  else
    Transformers.logger.info(
      "loading configuration file #{image_processor_file} from cache at #{resolved_image_processor_file}"
    )
  end

  if !is_local
    # NOTE(review): the keys were symbolized above, so these String lookups
    # never match and the guards below are effectively inactive. Kept as is
    # to avoid rejecting configurations that currently load; confirm the
    # intended behavior before switching to :auto_map / :custom_pipelines.
    if image_processor_dict.include?("auto_map")
      raise Todo
    end
    if image_processor_dict.include?("custom_pipelines")
      raise Todo
    end
  end

  [image_processor_dict, kwargs]
end
132
+
133
# Builds an image processor from a configuration hash.
#
# The configuration is copied before use. `:size` and `:crop_size` given in
# +kwargs+ replace the configuration values (only when the configuration
# already has that key) so that they go through the constructor. After
# construction, every remaining kwarg that names an instance variable
# already defined on the processor overwrites that variable and is removed
# from +kwargs+.
#
# @param image_processor_dict [Hash] symbol-keyed constructor arguments
# @param kwargs [Hash] overrides; `:return_unused_kwargs` controls the return shape
# @return [Object, Array(Object, Hash)] the processor, or the processor and
#   the kwargs that were not consumed when `return_unused_kwargs` is truthy
def self.from_dict(image_processor_dict, **kwargs)
  config = image_processor_dict.dup
  return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }

  # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
  # Setting it (and `crop_size`) on the configuration lets the constructor convert it, and
  # keeps it from being overwritten by the kwargs pass below.
  [:size, :crop_size].each do |name|
    config[name] = kwargs.delete(name) if kwargs.include?(name) && config.include?(name)
  end

  image_processor = new(**config)

  # Apply the kwargs that correspond to existing instance variables.
  # Iterating over a snapshot of the keys allows deleting as we go.
  kwargs.keys.each do |key|
    ivar = "@#{key}"
    next unless image_processor.instance_variable_defined?(ivar)

    image_processor.instance_variable_set(ivar, kwargs.delete(key))
  end

  Transformers.logger.info("Image processor #{image_processor}")
  return_unused_kwargs ? [image_processor, kwargs] : image_processor
end
168
+ end
169
+ end
@@ -0,0 +1,63 @@
1
+ # Copyright 2022 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
# Base class for image processors: calling an instance runs `preprocess`,
# and `rescale` / `normalize` delegate to the ImageTransforms helpers.
class BaseImageProcessor < ImageProcessingMixin
  # @param kwargs [Hash] passed to the superclass initializer
  def initialize(**kwargs)
    super(**kwargs)
  end

  # Shorthand for `preprocess`.
  #
  # @param images [Object] passed to `preprocess`
  # @param kwargs [Hash] passed to `preprocess`
  # @return [Object] whatever `preprocess` returns
  def call(images, **kwargs)
    preprocess(images, **kwargs)
  end

  # Hook that subclasses must override.
  #
  # @raise [NotImplementedError] always (note: in Ruby this is a ScriptError,
  #   so a bare `rescue` does not catch it)
  def preprocess(images, **kwargs)
    raise NotImplementedError, "Each image processor must implement its own preprocess method"
  end

  # Delegates to ImageTransforms.rescale.
  #
  # @param image [Object] image to rescale
  # @param scale [Object] scale factor
  # @param data_format [Object, nil] forwarded as is
  # @param input_data_format [Object, nil] forwarded as is
  # @param kwargs [Hash] extra keyword arguments forwarded as is
  # @return [Object] result of ImageTransforms.rescale
  def rescale(image, scale, data_format: nil, input_data_format: nil, **kwargs)
    formats = {data_format: data_format, input_data_format: input_data_format}
    ImageTransforms.rescale(image, scale, **formats, **kwargs)
  end

  # Delegates to ImageTransforms.normalize.
  #
  # @param image [Object] image to normalize
  # @param mean [Object] mean forwarded as is
  # @param std [Object] standard deviation forwarded as is
  # @param data_format [Object, nil] forwarded as is
  # @param input_data_format [Object, nil] forwarded as is
  # @param kwargs [Hash] extra keyword arguments forwarded as is
  # @return [Object] result of ImageTransforms.normalize
  def normalize(image, mean, std, data_format: nil, input_data_format: nil, **kwargs)
    formats = {data_format: data_format, input_data_format: input_data_format}
    ImageTransforms.normalize(image, mean, std, **formats, **kwargs)
  end
end
52
+
53
module ImageProcessingUtils
  # Normalizes a size specification to a hash.
  #
  # @param size [Hash, Object] a Hash is returned unchanged (same object);
  #   any other value is used for both `:height` and `:width`
  # @return [Hash] the size hash
  def self.get_size_dict(size)
    size.is_a?(Hash) ? size : {height: size, width: size}
  end
end
63
+ end
@@ -0,0 +1,208 @@
1
+ # Copyright 2022 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
+ module ImageTransforms
17
# Moves the channel axis of +image+ to the requested position.
#
# When +input_channel_dim+ is not given it is inferred with
# ImageUtils.infer_channel_dimension_format (the helper lives in ImageUtils,
# not in this module, so it must be called with its receiver). The image is
# returned untouched when it is already in the target format; otherwise it
# is transposed with axes (2, 0, 1) for channels-first or (1, 2, 0) for
# channels-last.
#
# @param image [Numo::NArray] image to convert
# @param channel_dim [Object] target format, wrapped in ChannelDimension
# @param input_channel_dim [Object, nil] current format of +image+
# @return [Numo::NArray] the image in the target format
# @raise [ArgumentError] for non-NArray input or an unsupported target format
def self.to_channel_dimension_format(
  image,
  channel_dim,
  input_channel_dim: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
  end

  input_channel_dim = ImageUtils.infer_channel_dimension_format(image) if input_channel_dim.nil?

  target_channel_dim = ChannelDimension.new(channel_dim).to_s
  return image if input_channel_dim == target_channel_dim

  if target_channel_dim == ChannelDimension::FIRST
    image.transpose(2, 0, 1)
  elsif target_channel_dim == ChannelDimension::LAST
    image.transpose(1, 2, 0)
  else
    raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}"
  end
end
45
+
46
# Multiplies +image+ by +scale+ and casts the result to +dtype+.
#
# When +data_format+ is given, the scaled image is converted with
# to_channel_dimension_format before the cast. +input_data_format+ is
# passed through the `input_channel_dim:` keyword (it used to be passed as
# a third positional argument, which that method does not accept).
#
# @param image [Numo::NArray] image to rescale
# @param scale [Numeric] multiplier applied to the image
# @param data_format [Object, nil] target channel format, if any
# @param dtype [Class] Numo type the result is cast to
# @param input_data_format [Object, nil] channel format of +image+
# @return [Numo::NArray] the rescaled image
# @raise [ArgumentError] when +image+ is not a Numo::NArray
def self.rescale(
  image,
  scale,
  data_format: nil,
  dtype: Numo::SFloat,
  input_data_format: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
  end

  rescaled_image = image * scale
  unless data_format.nil?
    rescaled_image = to_channel_dimension_format(
      rescaled_image, data_format, input_channel_dim: input_data_format
    )
  end

  rescaled_image.cast_to(dtype)
end
66
+
67
# Resizes +image+ to +size+ using Vips.
#
# Input that is not already a Vips::Image is converted with to_pil_image
# first. The resize itself is `thumbnail_image` with `size: :force`. With
# +return_numpy+ the result is converted back with ImageUtils.to_numo_array,
# given a trailing axis when it is two-dimensional, and moved to
# +data_format+ (which defaults to the input format).
#
# @param image [Vips::Image, Numo::NArray] image to resize
# @param size [Array] two elements, destructured as (height, width)
# @param resample [Object, nil] currently unused (resampling is not supported yet)
# @param reducing_gap [Object, nil] currently unused
# @param data_format [Object, nil] channel format of the returned array
# @param return_numpy [Boolean] return a Numo array instead of a Vips image
# @param input_data_format [Object, nil] channel format of +image+; inferred when nil
# @return [Object] the resized Vips image, or its array form when +return_numpy+ is truthy
# @raise [ArgumentError] when +size+ does not have exactly two elements
def self.resize(
  image,
  size,
  resample: nil,
  reducing_gap: nil,
  data_format: nil,
  return_numpy: true,
  input_data_format: nil
)
  raise ArgumentError, "size must have 2 elements" if size.length != 2

  # Keep the same data format as the input image unless told otherwise.
  # The resized image always comes back channels-last, so find the input format first.
  input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?
  data_format = input_data_format if data_format.nil?

  # Arrays are converted to a Vips image for resizing and converted back afterwards.
  do_rescale = false
  unless image.is_a?(Vips::Image)
    do_rescale = _rescale_for_pil_conversion(image)
    image = to_pil_image(image, do_rescale: do_rescale, input_data_format: input_data_format)
  end

  height, width = size
  # TODO support resample
  resized = image.thumbnail_image(width, height: height, size: :force)
  return resized unless return_numpy

  resized = ImageUtils.to_numo_array(resized)
  # A channel axis of size 1 may have been dropped by the image conversion; add it back.
  resized = resized.expand_dims(-1) if resized.ndim == 2
  # After the image conversion the data is always channels-last.
  resized = to_channel_dimension_format(resized, data_format, input_channel_dim: ChannelDimension::LAST)
  # Undo the [0, 255] rescaling applied before the image conversion, if any.
  do_rescale ? rescale(resized, 1 / 255.0) : resized
end
115
+
116
# Normalizes +image+ channel-wise: (image - mean) / std.
#
# +mean+ and +std+ may be scalars (repeated for every channel) or
# enumerables with one entry per channel. Integer images are cast to
# Numo::SFloat first. For channels-first input the image is transposed so
# that the statistics line up with the last axis, and transposed back.
#
# Two call bugs are fixed here: the input format is inferred through
# ImageUtils (the helper is not defined in this module), and
# +input_data_format+ is handed to to_channel_dimension_format through its
# `input_channel_dim:` keyword instead of as a third positional argument.
#
# @param image [Numo::NArray] image to normalize
# @param mean [Numeric, Enumerable] per-channel mean
# @param std [Numeric, Enumerable] per-channel standard deviation
# @param data_format [Object, nil] channel format of the result, if it should change
# @param input_data_format [Object, nil] channel format of +image+; inferred when nil
# @return [Numo::NArray] the normalized image
# @raise [ArgumentError] for non-NArray input or statistics of the wrong length
def self.normalize(
  image,
  mean,
  std,
  data_format: nil,
  input_data_format: nil
)
  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "image must be a numpy array"
  end

  input_data_format = ImageUtils.infer_channel_dimension_format(image) if input_data_format.nil?

  channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format)
  num_channels = image.shape[channel_axis]

  # Cast to float32 to avoid errors that can occur when subtracting uint8 values.
  # Float inputs keep their original type.
  if !image.is_a?(Numo::SFloat) && !image.is_a?(Numo::DFloat)
    image = image.cast_to(Numo::SFloat)
  end

  # Expands a scalar to one value per channel, validates the length of an
  # enumerable, and converts the result to a Numo::DFloat.
  per_channel = lambda do |values, label|
    if values.is_a?(Enumerable)
      if values.length != num_channels
        raise ArgumentError,
          "#{label} must have #{num_channels} elements if it is an iterable, got #{values.length}"
      end
    else
      values = [values] * num_channels
    end
    Numo::DFloat.cast(values)
  end

  mean = per_channel.call(mean, "mean")
  std = per_channel.call(std, "std")

  image =
    if input_data_format == ChannelDimension::LAST
      (image - mean) / std
    else
      ((image.transpose - mean) / std).transpose
    end

  return image if data_format.nil?

  to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format)
end
167
+
168
# Converts +image+ to a Vips::Image.
#
# A Vips::Image is returned as is. A Numo::NArray is moved to channels-last,
# optionally multiplied by 255 (decided by _rescale_for_pil_conversion when
# +do_rescale+ is nil), cast to Numo::UInt8 and loaded with
# Vips::Image.new_from_memory using shape[1] as width, shape[0] as height
# and shape[2] as the number of bands.
#
# @param image [Vips::Image, Numo::NArray] image to convert
# @param do_rescale [Boolean, nil] whether to multiply by 255 first
# @param input_data_format [Object, nil] channel format of +image+
# @return [Vips::Image] the converted image
# @raise [ArgumentError] for any other input type
def self.to_pil_image(
  image,
  do_rescale: nil,
  input_data_format: nil
)
  return image if image.is_a?(Vips::Image)

  unless image.is_a?(Numo::NArray)
    raise ArgumentError, "Input image type not supported: #{image.class.name}"
  end

  # If the channel has been moved to first dim, we put it back at the end.
  pixels = to_channel_dimension_format(image, ChannelDimension::LAST, input_channel_dim: input_data_format)

  # Rescale the image to be between 0 and 255 if needed.
  do_rescale = _rescale_for_pil_conversion(pixels) if do_rescale.nil?
  pixels = rescale(pixels, 255) if do_rescale

  pixels = pixels.cast_to(Numo::UInt8)
  Vips::Image.new_from_memory(pixels.to_binary, pixels.shape[1], pixels.shape[0], pixels.shape[2], :uchar)
end
198
+
199
# Tells whether +image+ must be rescaled before being turned into a Vips image.
#
# Only Numo::UInt8 input is handled so far, and it needs no rescaling.
#
# @param image [Object] image about to be converted
# @return [false] for Numo::UInt8 input
# @raise [Todo] for every other input
def self._rescale_for_pil_conversion(image)
  raise Todo unless image.is_a?(Numo::UInt8)

  false
end
207
+ end
208
+ end