informers 1.0.3 → 1.1.1

@@ -1,16 +1,65 @@
 module Informers
   class PreTrainedTokenizer
-    attr_reader :sep_token_id
+    attr_reader :mask_token, :mask_token_id, :sep_token_id

     def initialize(tokenizer_json, tokenizer_config)
       super()

+      @tokenizer_config = tokenizer_config
+
       @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)

-      @sep_token = tokenizer_config["sep_token"]
-      @sep_token_id = @tokenizer.token_to_id(@sep_token)
+      # Add added_tokens to model
+      @special_tokens = []
+      @all_special_ids = []
+
+      @added_tokens = []
+      @tokenizer.added_tokens_decoder.each do |id, token|
+        @added_tokens << token
+
+        if token.special
+          @special_tokens << token.content
+          @all_special_ids << id
+        end
+      end
+
+      # Update additional_special_tokens
+      @additional_special_tokens = tokenizer_config["additional_special_tokens"] || []
+      @special_tokens.concat(@additional_special_tokens)
+
+      @mask_token = get_token("mask_token")
+      @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token
+
+      @sep_token = get_token("sep_token")
+      @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token

       @model_max_length = tokenizer_config["model_max_length"]
+
+      # for donut-base-finetuned-docvqa
+      if @model_max_length && @model_max_length > (1 << 63)
+        @model_max_length = 1 << 63
+      end
+    end
+
+    def get_token(*keys)
+      keys.each do |key|
+        item = @tokenizer_config[key]
+        if !item
+          next
+        end
+
+        if item.is_a?(Hash)
+          if item["__type"] == "AddedToken"
+            return item["content"]
+          else
+            raise Error, "Unknown token: #{item}"
+          end
+        else
+          return item
+        end
+      end
+
+      nil
     end

     def call(
@@ -76,6 +125,22 @@ module Informers
     def convert_tokens_to_string(tokens)
       @tokenizer.decoder.decode(tokens)
     end
+
+    def convert_tokens_to_ids(tokens)
+      tokens.map { |t| @tokenizer.token_to_id(t) }
+    end
+
+    def id_to_token(id)
+      @tokenizer.id_to_token(id)
+    end
+
+    def batch_decode(batch, **decode_args)
+      @tokenizer.decode_batch(batch, **decode_args)
+    end
+
+    def padding_side=(side)
+      @tokenizer.enable_padding(direction: side)
+    end
   end

   class BertTokenizer < PreTrainedTokenizer
@@ -91,6 +156,16 @@ module Informers
   class DistilBertTokenizer < PreTrainedTokenizer
   end

+  class T5Tokenizer < PreTrainedTokenizer
+  end
+
+  class GPT2Tokenizer < PreTrainedTokenizer
+    # _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
+  end
+
+  class BartTokenizer < PreTrainedTokenizer
+  end
+
   class RobertaTokenizer < PreTrainedTokenizer
   end

@@ -100,14 +175,93 @@ module Informers
   class MPNetTokenizer < PreTrainedTokenizer
   end

+  class CLIPTokenizer < PreTrainedTokenizer
+  end
+
+  class NllbTokenizer < PreTrainedTokenizer
+    attr_reader :language_regex, :language_codes, :lang_to_token
+
+    def initialize(tokenizer_json, tokenizer_config)
+      super(tokenizer_json, tokenizer_config)
+
+      @language_regex = /^[a-z]{3}_[A-Z][a-z]{3}$/
+      @language_codes = @special_tokens.filter { |x| @language_regex.match?(x) }
+      @lang_to_token = ->(x) { x } # Identity function
+    end
+
+    def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
+      Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
+    end
+  end
+
+  class M2M100Tokenizer < PreTrainedTokenizer
+    attr_reader :language_regex, :language_codes, :lang_to_token
+
+    def initialize(tokenizer_json, tokenizer_config)
+      super(tokenizer_json, tokenizer_config)
+
+      @language_regex = /^__[a-z]{2,3}__$/
+      @language_codes = @special_tokens
+        .filter { |x| @language_regex.match?(x) }
+        .map { |x| x.slice(2, -2) }
+      @lang_to_token = ->(x) { "__#{x}__" }
+    end
+
+    def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
+      Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
+    end
+  end
+
+  module Utils
+    def self._build_translation_inputs(slf, raw_inputs, tokenizer_options, generate_kwargs)
+      if !slf.respond_to?(:language_codes) || !slf.language_codes.is_a?(Array)
+        raise Error, "Tokenizer must have `language_codes` attribute set and it should be an array of language ids."
+      end
+      if !slf.respond_to?(:language_regex) || !slf.language_regex.is_a?(Regexp)
+        raise Error, "Tokenizer must have `language_regex` attribute set and it should be a regular expression."
+      end
+      if !slf.respond_to?(:lang_to_token) || !slf.lang_to_token.respond_to?(:call)
+        raise Error, "Tokenizer must have `lang_to_token` attribute set and it should be a function."
+      end
+      src_lang_token = generate_kwargs[:src_lang]
+      tgt_lang_token = generate_kwargs[:tgt_lang]
+
+      if !slf.language_codes.include?(tgt_lang_token)
+        raise Error, "Target language code #{tgt_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(", ")}"
+      end
+
+      if !src_lang_token.nil?
+        # Check that the source language is valid:
+        if !slf.language_codes.include?(src_lang_token)
+          raise Error, "Source language code #{src_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(", ")}"
+        end
+      end
+
+      # Override the `forced_bos_token_id` to force the correct language
+      generate_kwargs["forced_bos_token_id"] = slf.convert_tokens_to_ids([slf.lang_to_token.(tgt_lang_token)])[0]
+
+      slf.(raw_inputs, **tokenizer_options)
+    end
+  end
+
+  class SpeechT5Tokenizer < PreTrainedTokenizer
+  end
+
   class AutoTokenizer
     TOKENIZER_CLASS_MAPPING = {
+      "T5Tokenizer" => T5Tokenizer,
       "BertTokenizer" => BertTokenizer,
       "DebertaV2Tokenizer" => DebertaV2Tokenizer,
       "DistilBertTokenizer" => DistilBertTokenizer,
+      "BartTokenizer" => BartTokenizer,
       "RobertaTokenizer" => RobertaTokenizer,
       "XLMRobertaTokenizer" => XLMRobertaTokenizer,
-      "MPNetTokenizer" => MPNetTokenizer
+      "MPNetTokenizer" => MPNetTokenizer,
+      "CLIPTokenizer" => CLIPTokenizer,
+      "GPT2Tokenizer" => GPT2Tokenizer,
+      "NllbTokenizer" => NllbTokenizer,
+      "M2M100Tokenizer" => M2M100Tokenizer,
+      "SpeechT5Tokenizer" => SpeechT5Tokenizer
     }

     def self.from_pretrained(
@@ -146,7 +300,7 @@ module Informers
     def self.load_tokenizer(pretrained_model_name_or_path, **options)
       info = [
         Utils::Hub.get_model_file(pretrained_model_name_or_path, "tokenizer.json", true, **options),
-        Utils::Hub.get_model_json(pretrained_model_name_or_path, "tokenizer_config.json", true, **options),
+        Utils::Hub.get_model_json(pretrained_model_name_or_path, "tokenizer_config.json", true, **options)
       ]

       # Override legacy option if `options.legacy` is not null
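A quick usage sketch of the new tokenizer surface (illustrative only; the model id and the expected return values below are assumptions, not part of the diff):

    tokenizer = Informers::AutoTokenizer.from_pretrained("bert-base-uncased")

    tokenizer.mask_token                                # => "[MASK]" (newly exposed reader)
    ids = tokenizer.convert_tokens_to_ids(["hello", "world"])
    tokenizer.id_to_token(ids.first)                    # => "hello"
    tokenizer.batch_decode([ids])                       # => ["hello world"], roughly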
@@ -0,0 +1,18 @@
+module Informers
+  module Utils
+    def self.read_audio(input, sampling_rate)
+      data =
+        if input.is_a?(URI)
+          require "open-uri"
+
+          input.read
+        elsif input.is_a?(String)
+          File.binread(input)
+        else
+          raise ArgumentError, "Unsupported input type: #{input.class.name}"
+        end
+
+      ffmpeg_read(data, sampling_rate)
+    end
+  end
+end
@@ -3,5 +3,9 @@ module Informers
     def self.dispatch_callback(progress_callback, data)
       progress_callback.(data) if progress_callback
     end
+
+    def self.calculate_reflect_offset(i, w)
+      ((i + w) % (2 * w) - w).abs
+    end
   end
 end
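The new calculate_reflect_offset helper maps an index that runs past a window of width w back into the window by mirroring, which is what reflect-padding needs. A quick check of the values it produces (assuming the nested module here is Informers::Utils, as the two closing ends and the surrounding helpers suggest):

    # Mirrored indices for w = 4: 0 1 2 3 4 3 2 1 0 1 ...
    (0..9).map { |i| Informers::Utils.calculate_reflect_offset(i, 4) }
    # => [0, 1, 2, 3, 4, 3, 2, 1, 0, 1]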
@@ -0,0 +1,45 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+module Informers
+  module Utils
+    # from the Transformers Python library
+    def self.ffmpeg_read(data, sampling_rate)
+      ar = "#{sampling_rate}"
+      ac = "1"
+      format_for_conversion = "f32le"
+      ffmpeg_command = [
+        "ffmpeg",
+        "-i",
+        "pipe:0",
+        "-ac",
+        ac,
+        "-ar",
+        ar,
+        "-f",
+        format_for_conversion,
+        "-hide_banner",
+        "-loglevel",
+        "quiet",
+        "pipe:1"
+      ]
+
+      stdout, status = Open3.capture2(*ffmpeg_command, stdin_data: data)
+      if !status.success?
+        raise Error, "ffmpeg was not found but is required to load audio files from filename"
+      end
+      stdout.unpack("e*")
+    end
+  end
+end
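read_audio accepts a local path or a URI, reads the raw bytes, and hands them to ffmpeg_read, which shells out to ffmpeg and unpacks its mono f32le output into an array of Floats at the requested sampling rate. A hedged sketch (the file name is hypothetical and ffmpeg must be on the PATH):

    samples = Informers::Utils.read_audio("speech.wav", 16000)
    samples.length    # number of mono samples at 16 kHz
    samples.first     # a Float, roughly in [-1.0, 1.0]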
@@ -0,0 +1,294 @@
+module Informers
+  module Utils
+    class GenerationConfig
+      def initialize(kwargs)
+        @config = {}
+
+        # Parameters that control the length of the output
+        @config["max_length"] = kwargs["max_length"] || 20
+        @config["max_new_tokens"] = kwargs["max_new_tokens"]
+        @config["min_length"] = kwargs["min_length"] || 0
+        @config["min_new_tokens"] = kwargs["min_new_tokens"]
+        @config["early_stopping"] = kwargs["early_stopping"] || false
+        @config["max_time"] = kwargs["max_time"]
+
+        # Parameters that control the generation strategy used
+        @config["do_sample"] = kwargs["do_sample"] || false
+        @config["num_beams"] = kwargs["num_beams"] || 1
+        @config["num_beam_groups"] = kwargs["num_beam_groups"] || 1
+        @config["penalty_alpha"] = kwargs["penalty_alpha"]
+        @config["use_cache"] = kwargs.fetch("use_cache", true)
+
+        # Parameters for manipulation of the model output logits
+        @config["temperature"] = kwargs["temperature"] || 1.0
+        @config["top_k"] = kwargs["top_k"] || 50
+        @config["top_p"] = kwargs["top_p"] || 1.0
+        @config["typical_p"] = kwargs["typical_p"] || 1.0
+        @config["epsilon_cutoff"] = kwargs["epsilon_cutoff"] || 0.0
+        @config["eta_cutoff"] = kwargs["eta_cutoff"] || 0.0
+        @config["diversity_penalty"] = kwargs["diversity_penalty"] || 0.0
+        @config["repetition_penalty"] = kwargs["repetition_penalty"] || 1.0
+        @config["encoder_repetition_penalty"] = kwargs["encoder_repetition_penalty"] || 1.0
+        @config["length_penalty"] = kwargs["length_penalty"] || 1.0
+        @config["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] || 0
+        @config["bad_words_ids"] = kwargs["bad_words_ids"]
+        @config["force_words_ids"] = kwargs["force_words_ids"]
+        @config["renormalize_logits"] = kwargs["renormalize_logits"] || false
+        @config["constraints"] = kwargs["constraints"]
+        @config["forced_bos_token_id"] = kwargs["forced_bos_token_id"]
+        @config["forced_eos_token_id"] = kwargs["forced_eos_token_id"]
+        @config["remove_invalid_values"] = kwargs["remove_invalid_values"] || false
+        @config["exponential_decay_length_penalty"] = kwargs["exponential_decay_length_penalty"]
+        @config["suppress_tokens"] = kwargs["suppress_tokens"]
+        @config["begin_suppress_tokens"] = kwargs["begin_suppress_tokens"]
+        @config["forced_decoder_ids"] = kwargs["forced_decoder_ids"]
+
+        # Parameters that define the output variables of `generate`
+        @config["num_return_sequences"] = kwargs["num_return_sequences"] || 1
+        @config["output_attentions"] = kwargs["output_attentions"] || false
+        @config["output_hidden_states"] = kwargs["output_hidden_states"] || false
+        @config["output_scores"] = kwargs["output_scores"] || false
+        @config["return_dict_in_generate"] = kwargs["return_dict_in_generate"] || false
+
+        # Special tokens that can be used at generation time
+        @config["pad_token_id"] = kwargs["pad_token_id"]
+        @config["bos_token_id"] = kwargs["bos_token_id"]
+        @config["eos_token_id"] = kwargs["eos_token_id"]
+
+        # Generation parameters exclusive to encoder-decoder models
+        @config["encoder_no_repeat_ngram_size"] = kwargs["encoder_no_repeat_ngram_size"] || 0
+        @config["decoder_start_token_id"] = kwargs["decoder_start_token_id"]
+
+        # Wild card
+        @generation_kwargs = kwargs["generation_kwargs"] || {}
+      end
+
+      def [](key)
+        @config[key.to_s]
+      end
+
+      def merge!(config)
+        @config.merge!(config)
+      end
+    end
+
+    class Sampler
+      def initialize(generation_config)
+        super()
+        @generation_config = generation_config
+      end
+
+      def call(logits, index = -1)
+        # Sample from logits, of dims [batch, sequence_length, vocab_size].
+        # If index is specified, sample from [batch, index, vocab_size].
+        sample(logits, index)
+      end
+
+      def get_logits(logits, index)
+        vocab_size = Utils.dims(logits)[-1]
+
+        logs = logits.flatten
+
+        if index == -1
+          logs = logs.last(vocab_size)
+        else
+          raise Todo
+        end
+
+        # add temperature
+        if @generation_config["temperature"] > 0
+          logs = logs.map { |x| x / @generation_config["temperature"] }
+        end
+        logs
+      end
+
+      def self.get_sampler(generation_config)
+        if generation_config[:do_sample]
+          MultinomialSampler.new(generation_config)
+        elsif generation_config[:num_beams] > 1
+          BeamSearchSampler.new(generation_config)
+        else
+          if generation_config[:num_return_sequences] > 1
+            raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{generation_config[:num_return_sequences]}."
+          end
+          GreedySampler.new(generation_config)
+        end
+      end
+    end
+
+    class GreedySampler < Sampler
+      def sample(logits, index = -1)
+        # NOTE: no need to do log_softmax here since we only take the maximum
+        logs = get_logits(logits, index)
+        argmax = Utils.max(logs)[1]
+
+        # Note: score is meaningless in this context, since we are performing
+        # greedy search (p = 1 => log(p) = 0)
+        [
+          [argmax, 0]
+        ]
+      end
+    end
+
+    class BeamSearchSampler < Sampler
+      def sample(logits, index = -1)
+        k = Utils.dims(logits)[-1] # defaults to vocab size
+        if @generation_config["top_k"] > 0
+          k = [@generation_config["top_k"], k].min
+        end
+
+        # Get logits of nth token
+        logs = get_logits(logits, index)
+
+        # Get top k tokens
+        top_logits = Utils.get_top_items(logs, k)
+
+        # Compute softmax over logits
+        probabilities = Utils.softmax(top_logits.map { |x| x[1] })
+
+        Array.new(@generation_config["num_beams"]) do |i|
+          [
+            top_logits[i][0],
+            Math.log(probabilities[i])
+          ]
+        end
+      end
+    end
+
+    class LogitsProcessorList
+      def initialize
+        super
+        @processors = []
+      end
+
+      def push(item)
+        @processors << item
+      end
+
+      def concat(items)
+        @processors.concat(items)
+      end
+
+      def call(input_ids, batched_logits)
+        # NOTE: This is different from the Python code, since vanilla Ruby does not support vectorized operations.
+        # As a result, we apply each processor to each item in the batch.
+        batched_logits.each do |logits|
+          # Modifies logits inplace
+          @processors.each do |func|
+            func.(input_ids, logits)
+          end
+        end
+      end
+
+      def to_ary
+        @processors
+      end
+    end
+
+    class LogitsProcessor
+    end
+
+    class NoRepeatNGramLogitsProcessor < LogitsProcessor
+      def initialize(no_repeat_ngram_size)
+        super()
+        @no_repeat_ngram_size = no_repeat_ngram_size
+      end
+
+      def get_ngrams(prev_input_ids)
+        cur_len = prev_input_ids.length
+
+        ngrams = []
+        j = 0
+        while j < cur_len + 1 - @no_repeat_ngram_size
+          ngram = []
+          @no_repeat_ngram_size.times do |k|
+            ngram << prev_input_ids[j + k]
+          end
+          ngrams << ngram
+          j += 1
+        end
+
+        generated_ngram = {}
+        ngrams.each do |ngram|
+          prev_ngram = ngram.slice(0, ngram.length - 1)
+          prev_ngram_key = JSON.generate(prev_ngram)
+          prev_ngram_value = generated_ngram[prev_ngram_key] || []
+          prev_ngram_value << ngram[ngram.length - 1]
+          generated_ngram[prev_ngram_key] = prev_ngram_value
+        end
+        generated_ngram
+      end
+
+      def get_generated_ngrams(banned_ngrams, prev_input_ids)
+        ngram_idx = prev_input_ids.slice(prev_input_ids.length + 1 - @no_repeat_ngram_size, prev_input_ids.length)
+        banned = banned_ngrams[JSON.generate(ngram_idx)] || []
+        banned
+      end
+
+      def calc_banned_ngram_tokens(prev_input_ids)
+        banned_tokens = []
+        if prev_input_ids.length + 1 < @no_repeat_ngram_size
+          # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+          banned_tokens
+        else
+          generated_ngrams = get_ngrams(prev_input_ids)
+          banned_tokens = get_generated_ngrams(generated_ngrams, prev_input_ids)
+          banned_tokens
+        end
+      end
+
+      def call(input_ids, logits)
+        banned_tokens = calc_banned_ngram_tokens(input_ids)
+
+        banned_tokens.each do |token|
+          logits[token] = -Float::INFINITY
+        end
+        logits
+      end
+    end
+
+    class MinLengthLogitsProcessor < LogitsProcessor
+      def initialize(min_length, eos_token_id)
+        super()
+        @min_length = min_length
+        @eos_token_id = eos_token_id.is_a?(Array) ? eos_token_id : [eos_token_id]
+      end
+
+      def call(input_ids, logits)
+        if input_ids.length < @min_length
+          @eos_token_id.each do |eos_token|
+            logits[eos_token] = -Float::INFINITY
+          end
+        end
+
+        logits
+      end
+    end
+
+    class ForcedBOSTokenLogitsProcessor < LogitsProcessor
+      def initialize(bos_token_id)
+        super()
+        @bos_token_id = bos_token_id
+      end
+
+      def call(input_ids, logits)
+        if input_ids.length == 1
+          logits.map! { -Float::INFINITY }
+          logits[@bos_token_id] = 0
+        end
+        logits
+      end
+    end
+
+    class ForcedEOSTokenLogitsProcessor < LogitsProcessor
+      def initialize(max_length, forced_eos_token_id)
+        super()
+        @max_length = max_length
+        @forced_eos_token_id = forced_eos_token_id
+      end
+
+      def call(input_ids, logits)
+      end
+    end
+  end
+end
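Roughly how these pieces fit together: GenerationConfig normalizes a kwargs hash into defaults, and Sampler.get_sampler picks a sampler from it; with do_sample false and num_beams 1 that is the greedy sampler. A small sketch based on the code above:

    config = Informers::Utils::GenerationConfig.new({"max_new_tokens" => 32})
    config[:max_length]    # => 20 (default)
    config[:num_beams]     # => 1

    sampler = Informers::Utils::Sampler.get_sampler(config)
    sampler.class          # => Informers::Utils::GreedySampler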
@@ -0,0 +1,116 @@
+module Informers
+  module Utils
+    class RawImage
+      RESAMPLING_MAPPING = {
+        0 => "nearest",
+        1 => "lanczos",
+        2 => "bilinear",
+        3 => "bicubic",
+        4 => "box",
+        5 => "hamming"
+      }
+
+      attr_reader :image, :width, :height, :channels
+
+      def initialize(image)
+        @image = image
+        @width = image.width
+        @height = image.height
+        @channels = image.bands
+      end
+
+      def data
+        @image.write_to_memory.unpack("C*")
+      end
+
+      def size
+        [@width, @height]
+      end
+
+      def resize(width, height, resample: 2)
+        resample_method = RESAMPLING_MAPPING[resample] || resample
+
+        case resample_method
+        when "bilinear", "bicubic"
+          img =
+            @image.affine(
+              [width / @width.to_f, 0, 0, height / @height.to_f],
+              interpolate: Vips::Interpolate.new(resample_method.to_sym)
+            )
+        else
+          raise Todo
+        end
+
+        RawImage.new(img)
+      end
+
+      def center_crop(crop_width, crop_height)
+        # If the image is already the desired size, return it
+        if @width == crop_width && @height == crop_height
+          return self
+        end
+
+        # Determine bounds of the image in the new canvas
+        width_offset = (@width - crop_width) / 2.0
+        height_offset = (@height - crop_height) / 2.0
+
+        if width_offset >= 0 && height_offset >= 0
+          # Cropped image lies entirely within the original image
+          img = @image.crop(
+            width_offset.floor,
+            height_offset.floor,
+            crop_width,
+            crop_height
+          )
+        elsif width_offset <= 0 && height_offset <= 0
+          raise Todo
+        else
+          raise Todo
+        end
+
+        RawImage.new(img)
+      end
+
+      def rgb
+        if @channels == 3
+          return self
+        end
+
+        raise Todo
+      end
+
+      def save(path)
+        @image.write_to_file(path)
+      end
+
+      def self.read(input)
+        if input.is_a?(RawImage)
+          input
+        elsif input.is_a?(URI)
+          require "open-uri"
+
+          RawImage.new(Vips::Image.new_from_buffer(input.read, ""))
+        elsif input.is_a?(String)
+          RawImage.new(Vips::Image.new_from_file(input))
+        else
+          raise ArgumentError, "Unsupported input type: #{input.class.name}"
+        end
+      end
+
+      def self.from_array(input)
+        c, h, w = Utils.dims(input)
+        pixel_data = Array.new(w * h * c)
+
+        input.each_with_index do |cv, ci|
+          cv.each_with_index do |hv, hi|
+            hv.each_with_index do |v, wi|
+              pixel_data[(hi * w * c) + (wi * c) + ci] = v
+            end
+          end
+        end
+
+        RawImage.new(Vips::Image.new_from_memory_copy(pixel_data.pack("C*"), w, h, c, :uchar))
+      end
+    end
+  end
+end
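RawImage is a thin wrapper around a ruby-vips image. A hedged sketch of the helpers above (the path is hypothetical and requires libvips via the ruby-vips gem):

    image = Informers::Utils::RawImage.read("cat.jpg")
    image.size                     # => [width, height]
    image.resize(224, 224)        # bilinear resample by default (resample: 2)
    image.center_crop(200, 200)   # crops around the center when the source is large enough
    image.rgb.channels            # => 3 when the image is already RGB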