transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/hf_hub/utils/_errors.rb
@@ -0,0 +1,94 @@
+ module Transformers
+   module HfHub
+     class HfHubHTTPError < Error
+       def initialize(message, response = nil)
+         super(message)
+       end
+     end
+
+     class RepositoryNotFoundError < HfHubHTTPError; end
+
+     class GatedRepoError < RepositoryNotFoundError; end
+
+     class DisabledRepoError < HfHubHTTPError; end
+
+     class RevisionNotFoundError < HfHubHTTPError; end
+
+     class EntryNotFoundError < HfHubHTTPError; end
+
+     class LocalEntryNotFoundError < EntryNotFoundError; end
+
+     class BadRequestError < HfHubHTTPError; end
+
+     class << self
+       def hf_raise_for_status(response, endpoint_name: nil)
+         begin
+           response.value unless response.is_a?(Net::HTTPRedirection)
+         rescue => e
+           error_code = response["X-Error-Code"]
+           error_message = response["X-Error-Message"]
+
+           if error_code == "RevisionNotFound"
+             message = "#{response.code} Client Error." + "\n\n" + "Revision Not Found for url: #{response.uri}."
+             raise RevisionNotFoundError.new(message, response)
+
+           elsif error_code == "EntryNotFound"
+             message = "#{response.code} Client Error." + "\n\n" + "Entry Not Found for url: #{response.uri}."
+             raise EntryNotFoundError.new(message, response)
+
+           elsif error_code == "GatedRepo"
+             message = (
+               "#{response.code} Client Error." + "\n\n" + "Cannot access gated repo for url #{response.uri}."
+             )
+             raise GatedRepoError.new(message, response)
+
+           elsif error_message == "Access to this resource is disabled."
+             message = (
+               "#{response.code} Client Error." +
+               "\n\n" +
+               "Cannot access repository for url #{response.uri}." +
+               "\n" +
+               "Access to this resource is disabled."
+             )
+             raise DisabledRepoError.new(message, response)
+
+           elsif error_code == "RepoNotFound"
+             # 401 is misleading as it is returned for:
+             #  - private and gated repos if user is not authenticated
+             #  - missing repos
+             # => for now, we process them as `RepoNotFound` anyway.
+             # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
+             message = (
+               "#{response.code} Client Error." +
+               "\n\n" +
+               "Repository Not Found for url: #{response.uri}." +
+               "\nPlease make sure you specified the correct `repo_id` and" +
+               " `repo_type`.\nIf you are trying to access a private or gated repo," +
+               " make sure you are authenticated."
+             )
+             raise RepositoryNotFoundError.new(message, response)
+
+           elsif response.code.to_i == 400
+             message = (
+               !endpoint_name.nil? ? "\n\nBad request for #{endpoint_name} endpoint:" : "\n\nBad request:"
+             )
+             raise BadRequestError.new(message, response)
+
+           elsif response.code.to_i == 403
+             message = (
+               "\n\n#{response.code} Forbidden: #{error_message}." +
+               "\nCannot access content at: #{response.uri}." +
+               "\nIf you are trying to create or update content, " +
+               "make sure you have a token with the `write` role."
+             )
+             raise HfHubHTTPError.new(message, response)
+           end
+
+           # Convert `HTTPError` into a `HfHubHTTPError` to display request information
+           # as well (request id and/or server error message)
+           raise HfHubHTTPError.new(e.to_s, response)
+         end
+       end
+     end
+   end
+ end
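
The hierarchy above maps Hub HTTP failures onto typed Ruby exceptions. A minimal usage sketch (not part of the gem's diff): it assumes only the `Transformers::HfHub.hf_raise_for_status` method and the error classes shown above; the model id and `endpoint_name` value are illustrative.

  require "net/http"

  uri = URI("https://huggingface.co/api/models/bert-base-uncased")
  response = Net::HTTP.get_response(uri)

  begin
    # Raises a typed error (RepositoryNotFoundError, RevisionNotFoundError, ...) on failure
    Transformers::HfHub.hf_raise_for_status(response, endpoint_name: "model_info")
  rescue Transformers::HfHub::RepositoryNotFoundError => e
    warn "repository not found: #{e.message}"
  rescue Transformers::HfHub::HfHubHTTPError => e
    warn "Hub request failed: #{e.message}"
  end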
data/lib/transformers/hf_hub/utils/_headers.rb
@@ -0,0 +1,109 @@
+ # Copyright 2022-present, the HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module HfHub
+     class << self
+       def build_hf_headers(
+         token: nil,
+         is_write_action: false,
+         library_name: nil,
+         library_version: nil,
+         user_agent: nil,
+         headers: nil
+       )
+         # Get auth token to send
+         token_to_send = get_token_to_send(token)
+         _validate_token_to_send(token_to_send, is_write_action)
+
+         # Combine headers
+         hf_headers = {
+           "user-agent" => _http_user_agent(
+             library_name: library_name,
+             library_version: library_version,
+             user_agent: user_agent
+           )
+         }
+         if !token_to_send.nil?
+           hf_headers["authorization"] = "Bearer #{token_to_send}"
+         end
+         if headers
+           hf_headers.merge!(headers)
+         end
+         hf_headers
+       end
+
+       def get_token_to_send(token)
+         # Case token is explicitly provided
+         if token.is_a?(String)
+           return token
+         end
+
+         # Case token is explicitly forbidden
+         if token == false
+           return nil
+         end
+
+         # Token is not provided: we get it from local cache
+         cached_token = nil # get_token
+
+         # Case token is explicitly required
+         if token == true
+           if cached_token.nil?
+             raise LocalTokenNotFoundError,
+               "Token is required (`token: true`), but no token found. You" +
+               " need to provide a token or be logged in to Hugging Face with" +
+               " `huggingface-cli login` or `huggingface_hub.login`. See" +
+               " https://huggingface.co/settings/tokens."
+           end
+           return cached_token
+         end
+
+         # Case implicit use of the token is forbidden by env variable
+         if HF_HUB_DISABLE_IMPLICIT_TOKEN
+           return nil
+         end
+
+         # Otherwise: we use the cached token as the user has not explicitly forbidden it
+         cached_token
+       end
+
+       def _validate_token_to_send(token, is_write_action)
+         if is_write_action
+           if token.nil?
+             raise ArgumentError,
+               "Token is required (write-access action) but no token found. You need" +
+               " to provide a token or be logged in to Hugging Face with" +
+               " `huggingface-cli login` or `huggingface_hub.login`. See" +
+               " https://huggingface.co/settings/tokens."
+           end
+         end
+       end
+
+       def _http_user_agent(
+         library_name: nil,
+         library_version: nil,
+         user_agent: nil
+       )
+         if !library_name.nil?
+           ua = "#{library_name}/#{library_version}"
+         else
+           ua = "unknown/None"
+         end
+         ua += "; ruby/#{RUBY_VERSION.to_f}"
+         ua
+       end
+     end
+   end
+ end
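
A short sketch of how the helpers above combine (again, not part of the diff): `build_hf_headers` produces the `user-agent` header and, when a token is available, an `authorization` header, which can then be passed to any HTTP client. The environment variable name is a placeholder; passing a String token skips the cached-token lookup entirely.

  headers = Transformers::HfHub.build_hf_headers(
    token: ENV["HF_TOKEN"],          # nil falls through to get_token_to_send's cache logic
    library_name: "transformers-rb",
    library_version: "0.1.0"
  )
  # => e.g. {"user-agent" => "transformers-rb/0.1.0; ruby/3.3", "authorization" => "Bearer hf_..."}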
data/lib/transformers/image_processing_base.rb
@@ -0,0 +1,169 @@
+ # Copyright 2020 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class ImageProcessingMixin
+     def self.from_pretrained(
+       pretrained_model_name_or_path,
+       cache_dir: nil,
+       force_download: false,
+       local_files_only: false,
+       token: nil,
+       revision: "main",
+       **kwargs
+     )
+       kwargs[:cache_dir] = cache_dir
+       kwargs[:force_download] = force_download
+       kwargs[:local_files_only] = local_files_only
+       kwargs[:revision] = revision
+
+       if !token.nil?
+         kwargs[:token] = token
+       end
+
+       image_processor_dict, kwargs = get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+       from_dict(image_processor_dict, **kwargs)
+     end
+
+     def self.get_image_processor_dict(
+       pretrained_model_name_or_path, **kwargs
+     )
+       cache_dir = kwargs.delete(:cache_dir)
+       force_download = kwargs.delete(:force_download) { false }
+       resume_download = kwargs.delete(:resume_download)
+       proxies = kwargs.delete(:proxies)
+       token = kwargs.delete(:token)
+       _use_auth_token = kwargs.delete(:use_auth_token)
+       local_files_only = kwargs.delete(:local_files_only) { false }
+       revision = kwargs.delete(:revision)
+       subfolder = kwargs.delete(:subfolder) { "" }
+
+       from_pipeline = kwargs.delete(:_from_pipeline)
+       from_auto_class = kwargs.delete(:_from_auto) { false }
+
+       user_agent = {file_type: "image processor", from_auto_class: from_auto_class}
+       if !from_pipeline.nil?
+         user_agent[:using_pipeline] = from_pipeline
+       end
+
+       if Utils::Hub.is_offline_mode && !local_files_only
+         Transformers.logger.info("Offline mode: forcing local_files_only: true")
+         local_files_only = true
+       end
+
+       pretrained_model_name_or_path = pretrained_model_name_or_path.to_s
+       is_local = Dir.exist?(pretrained_model_name_or_path)
+       if Dir.exist?(pretrained_model_name_or_path)
+         image_processor_file = File.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
+       end
+       if File.exist?(pretrained_model_name_or_path)
+         resolved_image_processor_file = pretrained_model_name_or_path
+         is_local = true
+       elsif Utils::Hub.is_remote_url(pretrained_model_name_or_path)
+         raise Todo
+       else
+         image_processor_file = IMAGE_PROCESSOR_NAME
+         begin
+           # Load from local folder or from cache or download from model Hub and cache
+           resolved_image_processor_file = Utils::Hub.cached_file(
+             pretrained_model_name_or_path,
+             image_processor_file,
+             cache_dir: cache_dir,
+             force_download: force_download,
+             proxies: proxies,
+             resume_download: resume_download,
+             local_files_only: local_files_only,
+             token: token,
+             user_agent: user_agent,
+             revision: revision,
+             subfolder: subfolder
+           )
+         rescue EnvironmentError
+           # Re-raise any environment error raised by `cached_file`. It will have a helpful
+           # error message adapted to the original exception.
+           raise
+         rescue
+           # For any other exception, we throw a generic error.
+           raise EnvironmentError,
+             "Can't load image processor for '#{pretrained_model_name_or_path}'. If you were trying to load" +
+             " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" +
+             " same name. Otherwise, make sure '#{pretrained_model_name_or_path}' is the correct path to a" +
+             " directory containing a #{IMAGE_PROCESSOR_NAME} file"
+         end
+       end
+
+       begin
+         image_processor_dict = JSON.load_file(resolved_image_processor_file).transform_keys(&:to_sym)
+       rescue JSON::ParserError
+         raise EnvironmentError,
+           "It looks like the config file at '#{resolved_image_processor_file}' is not a valid JSON file."
+       end
+
+       if is_local
+         Transformers.logger.info("loading configuration file #{resolved_image_processor_file}")
+       else
+         Transformers.logger.info(
+           "loading configuration file #{image_processor_file} from cache at #{resolved_image_processor_file}"
+         )
+       end
+
+       if !is_local
+         if image_processor_dict.include?(:auto_map)
+           raise Todo
+         end
+         if image_processor_dict.include?(:custom_pipelines)
+           raise Todo
+         end
+       end
+       [image_processor_dict, kwargs]
+     end
+
+     def self.from_dict(image_processor_dict, **kwargs)
+       image_processor_dict = image_processor_dict.dup
+       return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }
+
+       # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+       # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
+       # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
+       if kwargs.include?(:size) && image_processor_dict.include?(:size)
+         image_processor_dict[:size] = kwargs.delete(:size)
+       end
+       if kwargs.include?(:crop_size) && image_processor_dict.include?(:crop_size)
+         image_processor_dict[:crop_size] = kwargs.delete(:crop_size)
+       end
+
+       image_processor = new(**image_processor_dict)
+
+       # Update image_processor with kwargs if needed
+       to_remove = []
+       kwargs.each do |key, value|
+         if image_processor.instance_variable_defined?("@#{key}")
+           image_processor.instance_variable_set("@#{key}", value)
+           to_remove << key
+         end
+       end
+       to_remove.each do |key|
+         kwargs.delete(key)
+       end
+
+       Transformers.logger.info("Image processor #{image_processor}")
+       if return_unused_kwargs
+         [image_processor, kwargs]
+       else
+         image_processor
+       end
+     end
+   end
+ end
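
`from_dict` is the piece most easily exercised in isolation: it instantiates the processor from a symbol-keyed hash, lets explicit kwargs override matching instance variables after construction, and optionally returns whatever kwargs it could not consume. A sketch with a made-up subclass (MinimalProcessor is illustrative, not part of the gem):

  class MinimalProcessor < Transformers::ImageProcessingMixin
    def initialize(do_resize: true, size: nil)
      @do_resize = do_resize
      @size = size
    end
  end

  processor, unused = MinimalProcessor.from_dict(
    {do_resize: true, size: {height: 224, width: 224}},
    do_resize: false,            # overrides @do_resize after construction
    return_unused_kwargs: true
  )
  # unused => {} (every kwarg matched an instance variable or was consumed)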
data/lib/transformers/image_processing_utils.rb
@@ -0,0 +1,63 @@
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class BaseImageProcessor < ImageProcessingMixin
+     def initialize(**kwargs)
+       super(**kwargs)
+     end
+
+     def call(images, **kwargs)
+       preprocess(images, **kwargs)
+     end
+
+     def preprocess(images, **kwargs)
+       raise NotImplementedError, "Each image processor must implement its own preprocess method"
+     end
+
+     def rescale(
+       image,
+       scale,
+       data_format: nil,
+       input_data_format: nil,
+       **kwargs
+     )
+       ImageTransforms.rescale(image, scale, data_format: data_format, input_data_format: input_data_format, **kwargs)
+     end
+
+     def normalize(
+       image,
+       mean,
+       std,
+       data_format: nil,
+       input_data_format: nil,
+       **kwargs
+     )
+       ImageTransforms.normalize(
+         image, mean, std, data_format: data_format, input_data_format: input_data_format, **kwargs
+       )
+     end
+   end
+
+   module ImageProcessingUtils
+     def self.get_size_dict(size)
+       if !size.is_a?(Hash)
+         size_dict = {height: size, width: size}
+       else
+         size_dict = size
+       end
+       size_dict
+     end
+   end
+ end
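
For reference, the `get_size_dict` normalization above behaves as follows (illustrative calls, not part of the diff):

  Transformers::ImageProcessingUtils.get_size_dict(224)
  # => {height: 224, width: 224}
  Transformers::ImageProcessingUtils.get_size_dict({height: 384, width: 512})
  # => {height: 384, width: 512}   (hashes pass through unchanged)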
data/lib/transformers/image_transforms.rb
@@ -0,0 +1,208 @@
+ # Copyright 2022 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module ImageTransforms
+     def self.to_channel_dimension_format(
+       image,
+       channel_dim,
+       input_channel_dim: nil
+     )
+       if !image.is_a?(Numo::NArray)
+         raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
+       end
+
+       if input_channel_dim.nil?
+         input_channel_dim = ImageUtils.infer_channel_dimension_format(image)
+       end
+
+       target_channel_dim = ChannelDimension.new(channel_dim).to_s
+       if input_channel_dim == target_channel_dim
+         return image
+       end
+
+       if target_channel_dim == ChannelDimension::FIRST
+         image = image.transpose(2, 0, 1)
+       elsif target_channel_dim == ChannelDimension::LAST
+         image = image.transpose(1, 2, 0)
+       else
+         raise ArgumentError, "Unsupported channel dimension format: #{channel_dim}"
+       end
+
+       image
+     end
+
+     def self.rescale(
+       image,
+       scale,
+       data_format: nil,
+       dtype: Numo::SFloat,
+       input_data_format: nil
+     )
+       if !image.is_a?(Numo::NArray)
+         raise ArgumentError, "Input image must be of type Numo::NArray, got #{image.class.name}"
+       end
+
+       rescaled_image = image * scale
+       if !data_format.nil?
+         rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_channel_dim: input_data_format)
+       end
+
+       rescaled_image = rescaled_image.cast_to(dtype)
+
+       rescaled_image
+     end
+
+     def self.resize(
+       image,
+       size,
+       resample: nil,
+       reducing_gap: nil,
+       data_format: nil,
+       return_numpy: true,
+       input_data_format: nil
+     )
+       resample = !resample.nil? ? resample : nil # PILImageResampling.BILINEAR
+
+       if size.length != 2
+         raise ArgumentError, "size must have 2 elements"
+       end
+
+       # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
+       # The resized image from PIL will always have channels last, so find the input format first.
+       if input_data_format.nil?
+         input_data_format = ImageUtils.infer_channel_dimension_format(image)
+       end
+       data_format = data_format.nil? ? input_data_format : data_format
+
+       # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
+       # the pillow library to resize the image and then convert back to numpy
+       do_rescale = false
+       if !image.is_a?(Vips::Image)
+         do_rescale = _rescale_for_pil_conversion(image)
+         image = to_pil_image(image, do_rescale: do_rescale, input_data_format: input_data_format)
+       end
+       height, width = size
+       # TODO support resample
+       resized_image = image.thumbnail_image(width, height: height, size: :force)
+
+       if return_numpy
+         resized_image = ImageUtils.to_numo_array(resized_image)
+         # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
+         # so we need to add it back if necessary.
+         resized_image = resized_image.ndim == 2 ? resized_image.expand_dims(-1) : resized_image
+         # The image is always in channels last format after converting from a PIL image
+         resized_image = to_channel_dimension_format(
+           resized_image, data_format, input_channel_dim: ChannelDimension::LAST
+         )
+         # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
+         # rescale it back to the original range.
+         resized_image = do_rescale ? rescale(resized_image, 1 / 255.0) : resized_image
+       end
+       resized_image
+     end
+
+     def self.normalize(
+       image,
+       mean,
+       std,
+       data_format: nil,
+       input_data_format: nil
+     )
+       if !image.is_a?(Numo::NArray)
+         raise ArgumentError, "image must be a Numo::NArray"
+       end
+
+       if input_data_format.nil?
+         input_data_format = ImageUtils.infer_channel_dimension_format(image)
+       end
+
+       channel_axis = ImageUtils.get_channel_dimension_axis(image, input_data_format: input_data_format)
+       num_channels = image.shape[channel_axis]
+
+       # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
+       # We preserve the original dtype if it is a float type to prevent upcasting float16.
+       if !image.is_a?(Numo::SFloat) && !image.is_a?(Numo::DFloat)
+         image = image.cast_to(Numo::SFloat)
+       end
+
+       if mean.is_a?(Enumerable)
+         if mean.length != num_channels
+           raise ArgumentError, "mean must have #{num_channels} elements if it is an iterable, got #{mean.length}"
+         end
+       else
+         mean = [mean] * num_channels
+       end
+       mean = Numo::DFloat.cast(mean)
+
+       if std.is_a?(Enumerable)
+         if std.length != num_channels
+           raise ArgumentError, "std must have #{num_channels} elements if it is an iterable, got #{std.length}"
+         end
+       else
+         std = [std] * num_channels
+       end
+       std = Numo::DFloat.cast(std)
+
+       if input_data_format == ChannelDimension::LAST
+         image = (image - mean) / std
+       else
+         image = ((image.transpose - mean) / std).transpose
+       end
+
+       image = !data_format.nil? ? to_channel_dimension_format(image, data_format, input_channel_dim: input_data_format) : image
+       image
+     end
+
+     def self.to_pil_image(
+       image,
+       do_rescale: nil,
+       input_data_format: nil
+     )
+       if image.is_a?(Vips::Image)
+         return image
+       end
+
+       # Convert all tensors to numo arrays before converting to Vips image
+       if !image.is_a?(Numo::NArray)
+         raise ArgumentError, "Input image type not supported: #{image.class.name}"
+       end
+
+       # If the channel has been moved to first dim, we put it back at the end.
+       image = to_channel_dimension_format(image, ChannelDimension::LAST, input_channel_dim: input_data_format)
+
+       # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
+       # image = image.shape[-1] == 1 ? image.squeeze(-1) : image
+
+       # Rescale the image to be between 0 and 255 if needed.
+       do_rescale = do_rescale.nil? ? _rescale_for_pil_conversion(image) : do_rescale
+
+       if do_rescale
+         image = rescale(image, 255)
+       end
+
+       image = image.cast_to(Numo::UInt8)
+       Vips::Image.new_from_memory(image.to_binary, image.shape[1], image.shape[0], image.shape[2], :uchar)
+     end
+
+     def self._rescale_for_pil_conversion(image)
+       if image.is_a?(Numo::UInt8)
+         do_rescale = false
+       else
+         raise Todo
+       end
+       do_rescale
+     end
+   end
+ end
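
Putting the transforms together, a hedged end-to-end sketch (not part of the diff). It assumes `ChannelDimension`, `ImageUtils.get_channel_dimension_axis`, and Numo are available under the namespaces the code above already references, and uses the common ImageNet mean/std purely as example values.

  require "numo/narray"

  # Synthetic channels-last uint8 image (height, width, channels)
  image = Numo::UInt8.new(224, 224, 3).rand(256)

  scaled = Transformers::ImageTransforms.rescale(image, 1 / 255.0)   # float32 in [0, 1]
  normalized = Transformers::ImageTransforms.normalize(
    scaled,
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
    input_data_format: Transformers::ChannelDimension::LAST
  )
  # Reorder to channels-first, as the vision models in this gem expect
  chw = Transformers::ImageTransforms.to_channel_dimension_format(
    normalized,
    Transformers::ChannelDimension::FIRST,
    input_channel_dim: Transformers::ChannelDimension::LAST
  )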