informers 1.0.2 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,65 @@
1
1
  module Informers
2
2
  class PreTrainedTokenizer
3
- attr_reader :sep_token_id
3
+ attr_reader :mask_token, :mask_token_id, :sep_token_id
4
4
 
5
5
  def initialize(tokenizer_json, tokenizer_config)
6
6
  super()
7
7
 
8
+ @tokenizer_config = tokenizer_config
9
+
8
10
  @tokenizer = Tokenizers::Tokenizer.from_file(tokenizer_json)
9
11
 
10
- @sep_token = tokenizer_config["sep_token"]
11
- @sep_token_id = @tokenizer.token_to_id(@sep_token)
12
+ # Add added_tokens to model
13
+ @special_tokens = []
14
+ @all_special_ids = []
15
+
16
+ @added_tokens = []
17
+ @tokenizer.added_tokens_decoder.each do |id, token|
18
+ @added_tokens << token
19
+
20
+ if token.special
21
+ @special_tokens << token.content
22
+ @all_special_ids << id
23
+ end
24
+ end
25
+
26
+ # Update additional_special_tokens
27
+ @additional_special_tokens = tokenizer_config["additional_special_tokens"] || []
28
+ @special_tokens.concat(@additional_special_tokens)
29
+
30
+ @mask_token = get_token("mask_token")
31
+ @mask_token_id = @tokenizer.token_to_id(@mask_token) if @mask_token
32
+
33
+ @sep_token = get_token("sep_token")
34
+ @sep_token_id = @tokenizer.token_to_id(@sep_token) if @sep_token
12
35
 
13
36
  @model_max_length = tokenizer_config["model_max_length"]
37
+
38
+ # for donut-base-finetuned-docvqa
39
+ if @model_max_length && @model_max_length > (1 << 63)
40
+ @model_max_length = 1 << 63
41
+ end
42
+ end
43
+
44
+ def get_token(*keys)
45
+ keys.each do |key|
46
+ item = @tokenizer_config[key]
47
+ if !item
48
+ next
49
+ end
50
+
51
+ if item.is_a?(Hash)
52
+ if item["__type"] == "AddedToken"
53
+ return item["content"]
54
+ else
55
+ raise Error, "Unknown token: #{item}"
56
+ end
57
+ else
58
+ return item
59
+ end
60
+ end
61
+
62
+ nil
14
63
  end
15
64
 
16
65
  def call(
@@ -76,6 +125,22 @@ module Informers
76
125
  def convert_tokens_to_string(tokens)
77
126
  @tokenizer.decoder.decode(tokens)
78
127
  end
128
+
129
+ def convert_tokens_to_ids(tokens)
130
+ tokens.map { |t| @tokenizer.token_to_id(t) }
131
+ end
132
+
133
+ def id_to_token(id)
134
+ @tokenizer.id_to_token(id)
135
+ end
136
+
137
+ def batch_decode(batch, **decode_args)
138
+ @tokenizer.decode_batch(batch, **decode_args)
139
+ end
140
+
141
+ def padding_side=(side)
142
+ @tokenizer.enable_padding(direction: side)
143
+ end
79
144
  end
80
145
 
81
146
  class BertTokenizer < PreTrainedTokenizer
@@ -91,11 +156,108 @@ module Informers
91
156
  class DistilBertTokenizer < PreTrainedTokenizer
92
157
  end
93
158
 
159
+ class T5Tokenizer < PreTrainedTokenizer
160
+ end
161
+
162
+ class GPT2Tokenizer < PreTrainedTokenizer
163
+ # _default_chat_template = `{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}`
164
+ end
165
+
166
+ class BartTokenizer < PreTrainedTokenizer
167
+ end
168
+
169
+ class RobertaTokenizer < PreTrainedTokenizer
170
+ end
171
+
172
+ class XLMRobertaTokenizer < PreTrainedTokenizer
173
+ end
174
+
175
+ class MPNetTokenizer < PreTrainedTokenizer
176
+ end
177
+
178
+ class CLIPTokenizer < PreTrainedTokenizer
179
+ end
180
+
181
+ class NllbTokenizer < PreTrainedTokenizer
182
+ attr_reader :language_regex, :language_codes, :lang_to_token
183
+
184
+ def initialize(tokenizer_json, tokenizer_config)
185
+ super(tokenizer_json, tokenizer_config)
186
+
187
+ @language_regex = /^[a-z]{3}_[A-Z][a-z]{3}$/
188
+ @language_codes = @special_tokens.filter { |x| @language_regex.match?(x) }
189
+ @lang_to_token = ->(x) { x } # Identity function
190
+ end
191
+
192
+ def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
193
+ Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
194
+ end
195
+ end
196
+
197
+ class M2M100Tokenizer < PreTrainedTokenizer
198
+ attr_reader :language_regex, :language_codes, :lang_to_token
199
+
200
+ def initialize(tokenizer_json, tokenizer_config)
201
+ super(tokenizer_json, tokenizer_config)
202
+
203
+ @language_regex = /^__[a-z]{2,3}__$/
204
+ @language_codes = @special_tokens
205
+ .filter { |x| @language_regex.match?(x) }
206
+ .map { |x| x.slice(2, -2) }
207
+ @lang_to_token = ->(x) { "__#{x}__" }
208
+ end
209
+
210
+ def _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs)
211
+ Utils._build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs)
212
+ end
213
+ end
214
+
215
+ module Utils
216
+ def self._build_translation_inputs(slf, raw_inputs, tokenizer_options, generate_kwargs)
217
+ if !slf.respond_to?(:language_codes) || !slf.language_codes.is_a?(Array)
218
+ raise Error, "Tokenizer must have `language_codes` attribute set and it should be an array of language ids."
219
+ end
220
+ if !slf.respond_to?(:language_regex) || !slf.language_regex.is_a?(Regexp)
221
+ raise Error, "Tokenizer must have `language_regex` attribute set and it should be a regular expression."
222
+ end
223
+ if !slf.respond_to?(:lang_to_token) || !slf.lang_to_token.respond_to?(:call)
224
+ raise Error, "Tokenizer must have `lang_to_token` attribute set and it should be a function."
225
+ end
226
+ src_lang_token = generate_kwargs[:src_lang]
227
+ tgt_lang_token = generate_kwargs[:tgt_lang]
228
+
229
+ if !slf.language_codes.include?(tgt_lang_token)
230
+ raise Error, "Target language code #{tgt_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(", ")}"
231
+ end
232
+
233
+ if !src_lang_token.nil?
234
+ # Check that the source language is valid:
235
+ if !slf.language_codes.include?(src_lang_token)
236
+ raise Error, "Source language code #{src_lang_token.inspect} is not valid. Must be one of: #{slf.language_codes.join(", ")}"
237
+ end
238
+ end
239
+
240
+ # Override the `forced_bos_token_id` to force the correct language
241
+ generate_kwargs["forced_bos_token_id"] = slf.convert_tokens_to_ids([slf.lang_to_token.(tgt_lang_token)])[0]
242
+
243
+ slf.(raw_inputs, **tokenizer_options)
244
+ end
245
+ end
246
+
94
247
  class AutoTokenizer
95
248
  TOKENIZER_CLASS_MAPPING = {
249
+ "T5Tokenizer" => T5Tokenizer,
96
250
  "BertTokenizer" => BertTokenizer,
97
251
  "DebertaV2Tokenizer" => DebertaV2Tokenizer,
98
- "DistilBertTokenizer" => DistilBertTokenizer
252
+ "DistilBertTokenizer" => DistilBertTokenizer,
253
+ "BartTokenizer" => BartTokenizer,
254
+ "RobertaTokenizer" => RobertaTokenizer,
255
+ "XLMRobertaTokenizer" => XLMRobertaTokenizer,
256
+ "MPNetTokenizer" => MPNetTokenizer,
257
+ "CLIPTokenizer" => CLIPTokenizer,
258
+ "GPT2Tokenizer" => GPT2Tokenizer,
259
+ "NllbTokenizer" => NllbTokenizer,
260
+ "M2M100Tokenizer" => M2M100Tokenizer
99
261
  }
100
262
 
101
263
  def self.from_pretrained(
@@ -3,5 +3,9 @@ module Informers
3
3
  def self.dispatch_callback(progress_callback, data)
4
4
  progress_callback.(data) if progress_callback
5
5
  end
6
+
7
+ def self.calculate_reflect_offset(i, w)
8
+ ((i + w) % (2 * w) - w).abs
9
+ end
6
10
  end
7
11
  end
@@ -0,0 +1,294 @@
1
+ module Informers
2
+ module Utils
3
+ class GenerationConfig
4
+ def initialize(kwargs)
5
+ @config = {}
6
+
7
+ # Parameters that control the length of the output
8
+ @config["max_length"] = kwargs["max_length"] || 20
9
+ @config["max_new_tokens"] = kwargs["max_new_tokens"]
10
+ @config["min_length"] = kwargs["min_length"] || 0
11
+ @config["min_new_tokens"] = kwargs["min_new_tokens"]
12
+ @config["early_stopping"] = kwargs["early_stopping"] || false
13
+ @config["max_time"] = kwargs["max_time"]
14
+
15
+ # Parameters that control the generation strategy used
16
+ @config["do_sample"] = kwargs["do_sample"] || false
17
+ @config["num_beams"] = kwargs["num_beams"] || 1
18
+ @config["num_beam_groups"] = kwargs["num_beam_groups"] || 1
19
+ @config["penalty_alpha"] = kwargs["penalty_alpha"]
20
+ @config["use_cache"] = kwargs.fetch("use_cache", true)
21
+
22
+ # Parameters for manipulation of the model output logits
23
+ @config["temperature"] = kwargs["temperature"] || 1.0
24
+ @config["top_k"] = kwargs["top_k"] || 50
25
+ @config["top_p"] = kwargs["top_p"] || 1.0
26
+ @config["typical_p"] = kwargs["typical_p"] || 1.0
27
+ @config["epsilon_cutoff"] = kwargs["epsilon_cutoff"] || 0.0
28
+ @config["eta_cutoff"] = kwargs["eta_cutoff"] || 0.0
29
+ @config["diversity_penalty"] = kwargs["diversity_penalty"] || 0.0
30
+ @config["repetition_penalty"] = kwargs["repetition_penalty"] || 1.0
31
+ @config["encoder_repetition_penalty"] = kwargs["encoder_repetition_penalty"] || 1.0
32
+ @config["length_penalty"] = kwargs["length_penalty"] || 1.0
33
+ @config["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] || 0
34
+ @config["bad_words_ids"] = kwargs["bad_words_ids"]
35
+ @config["force_words_ids"] = kwargs["force_words_ids"]
36
+ @config["renormalize_logits"] = kwargs["renormalize_logits"] || false
37
+ @config["constraints"] = kwargs["constraints"]
38
+ @config["forced_bos_token_id"] = kwargs["forced_bos_token_id"]
39
+ @config["forced_eos_token_id"] = kwargs["forced_eos_token_id"]
40
+ @config["remove_invalid_values"] = kwargs["remove_invalid_values"] || false
41
+ @config["exponential_decay_length_penalty"] = kwargs["exponential_decay_length_penalty"]
42
+ @config["suppress_tokens"] = kwargs["suppress_tokens"]
43
+ @config["begin_suppress_tokens"] = kwargs["begin_suppress_tokens"]
44
+ @config["forced_decoder_ids"] = kwargs["forced_decoder_ids"]
45
+
46
+ # Parameters that define the output variables of `generate`
47
+ @config["num_return_sequences"] = kwargs["num_return_sequences"] || 1
48
+ @config["output_attentions"] = kwargs["output_attentions"] || false
49
+ @config["output_hidden_states"] = kwargs["output_hidden_states"] || false
50
+ @config["output_scores"] = kwargs["output_scores"] || false
51
+ @config["return_dict_in_generate"] = kwargs["return_dict_in_generate"] || false
52
+
53
+ # Special tokens that can be used at generation time
54
+ @config["pad_token_id"] = kwargs["pad_token_id"]
55
+ @config["bos_token_id"] = kwargs["bos_token_id"]
56
+ @config["eos_token_id"] = kwargs["eos_token_id"]
57
+
58
+ # Generation parameters exclusive to encoder-decoder models
59
+ @config["encoder_no_repeat_ngram_size"] = kwargs["encoder_no_repeat_ngram_size"] || 0
60
+ @config["decoder_start_token_id"] = kwargs["decoder_start_token_id"]
61
+
62
+ # Wild card
63
+ @generation_kwargs = kwargs["generation_kwargs"] || {}
64
+ end
65
+
66
+ def [](key)
67
+ @config[key.to_s]
68
+ end
69
+
70
+ def merge!(config)
71
+ @config.merge!(config)
72
+ end
73
+ end
74
+
75
+ class Sampler
76
+ def initialize(generation_config)
77
+ super()
78
+ @generation_config = generation_config
79
+ end
80
+
81
+ def call(logits, index = -1)
82
+ # Sample from logits, of dims [batch, sequence_length, vocab_size].
83
+ # If index is specified, sample from [batch, index, vocab_size].
84
+ sample(logits, index)
85
+ end
86
+
87
+ def get_logits(logits, index)
88
+ vocab_size = Utils.dims(logits)[-1]
89
+
90
+ logs = logits.flatten
91
+
92
+ if index == -1
93
+ logs = logs.last(vocab_size)
94
+ else
95
+ raise Todo
96
+ end
97
+
98
+ # add temperature
99
+ if @generation_config["temperature"] > 0
100
+ logs = logs.map { |x| x / @generation_config["temperature"] }
101
+ end
102
+ logs
103
+ end
104
+
105
+ def self.get_sampler(generation_config)
106
+ if generation_config[:do_sample]
107
+ MultinomialSampler.new(generation_config)
108
+ elsif generation_config[:num_beams] > 1
109
+ BeamSearchSampler.new(generation_config)
110
+ else
111
+ if generation_config[:num_return_sequences] > 1
112
+ raise Error, "num_return_sequences has to be 1 when doing greedy search, but is #{generation_config[:num_return_sequences]}."
113
+ end
114
+ GreedySampler.new(generation_config)
115
+ end
116
+ end
117
+ end
118
+
119
+ class GreedySampler < Sampler
120
+ def sample(logits, index = -1)
121
+ # NOTE: no need to do log_softmax here since we only take the maximum
122
+ logs = get_logits(logits, index)
123
+ argmax = Utils.max(logs)[1]
124
+
125
+ # Note: score is meaningless in this context, since we are performing
126
+ # greedy search (p = 1 => log(p) = 0)
127
+ [
128
+ [argmax, 0]
129
+ ]
130
+ end
131
+ end
132
+
133
+ class BeamSearchSampler < Sampler
134
+ def sample(logits, index = -1)
135
+ k = Utils.dims(logits)[-1] # defaults to vocab size
136
+ if @generation_config["top_k"] > 0
137
+ k = [@generation_config["top_k"], k].min
138
+ end
139
+
140
+ # Get logits of nth token
141
+ logs = get_logits(logits, index)
142
+
143
+ # Get top k tokens
144
+ top_logits = Utils.get_top_items(logs, k)
145
+
146
+ # Compute softmax over logits
147
+ probabilities = Utils.softmax(top_logits.map { |x| x[1] })
148
+
149
+ Array.new(@generation_config["num_beams"]) do |i|
150
+ [
151
+ top_logits[i][0],
152
+ Math.log(probabilities[i])
153
+ ]
154
+ end
155
+ end
156
+ end
157
+
158
+ class LogitsProcessorList
159
+ def initialize
160
+ super
161
+ @processors = []
162
+ end
163
+
164
+ def push(item)
165
+ @processors << item
166
+ end
167
+
168
+ def concat(items)
169
+ @processors.concat(items)
170
+ end
171
+
172
+ def call(input_ids, batched_logits)
173
+ # NOTE: This is different from the Python code, since vanilla Ruby does not support vectorized operations.
174
+ # As a result, we apply each processor to each item in the batch.
175
+ batched_logits.each do |logits|
176
+ # Modifies logits inplace
177
+ @processors.each do |func|
178
+ func.(input_ids, logits)
179
+ end
180
+ end
181
+ end
182
+
183
+ def to_ary
184
+ @processors
185
+ end
186
+ end
187
+
188
+ class LogitsProcessor
189
+ end
190
+
191
+ class NoRepeatNGramLogitsProcessor < LogitsProcessor
192
+ def initialize(no_repeat_ngram_size)
193
+ super()
194
+ @no_repeat_ngram_size = no_repeat_ngram_size
195
+ end
196
+
197
+ def get_ngrams(prev_input_ids)
198
+ cur_len = prev_input_ids.length
199
+
200
+ ngrams = []
201
+ j = 0
202
+ while j < cur_len + 1 - @no_repeat_ngram_size
203
+ ngram = []
204
+ @no_repeat_ngram_size.times do |k|
205
+ ngram << prev_input_ids[j + k]
206
+ end
207
+ ngrams << ngram
208
+ j += 1
209
+ end
210
+
211
+ generated_ngram = {}
212
+ ngrams.each do |ngram|
213
+ prev_ngram = ngram.slice(0, ngram.length - 1)
214
+ prev_ngram_key = JSON.generate(prev_ngram)
215
+ prev_ngram_value = generated_ngram[prev_ngram_key] || []
216
+ prev_ngram_value << ngram[ngram.length - 1]
217
+ generated_ngram[prev_ngram_key] = prev_ngram_value
218
+ end
219
+ generated_ngram
220
+ end
221
+
222
+ def get_generated_ngrams(banned_ngrams, prev_input_ids)
223
+ ngram_idx = prev_input_ids.slice(prev_input_ids.length + 1 - @no_repeat_ngram_size, prev_input_ids.length)
224
+ banned = banned_ngrams[JSON.generate(ngram_idx)] || []
225
+ banned
226
+ end
227
+
228
+ def calc_banned_ngram_tokens(prev_input_ids)
229
+ banned_tokens = []
230
+ if prev_input_ids.length + 1 < @no_repeat_ngram_size
231
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
232
+ banned_tokens
233
+ else
234
+ generated_ngrams = get_ngrams(prev_input_ids)
235
+ banned_tokens = get_generated_ngrams(generated_ngrams, prev_input_ids)
236
+ banned_tokens
237
+ end
238
+ end
239
+
240
+ def call(input_ids, logits)
241
+ banned_tokens = calc_banned_ngram_tokens(input_ids)
242
+
243
+ banned_tokens.each do |token|
244
+ logits[token] = -Float::INFINITY
245
+ end
246
+ logits
247
+ end
248
+ end
249
+
250
+ class MinLengthLogitsProcessor < LogitsProcessor
251
+ def initialize(min_length, eos_token_id)
252
+ super()
253
+ @min_length = min_length
254
+ @eos_token_id = eos_token_id.is_a?(Array) ? eos_token_id : [eos_token_id]
255
+ end
256
+
257
+ def call(input_ids, logits)
258
+ if input_ids.length < @min_length
259
+ @eos_token_id.each do |eos_token|
260
+ logits[eos_token] = -Float::INFINITY
261
+ end
262
+ end
263
+
264
+ logits
265
+ end
266
+ end
267
+
268
+ class ForcedBOSTokenLogitsProcessor < LogitsProcessor
269
+ def initialize(bos_token_id)
270
+ super()
271
+ @bos_token_id = bos_token_id
272
+ end
273
+
274
+ def call(input_ids, logits)
275
+ if input_ids.length == 1
276
+ logits.map! { -Float::INFINITY }
277
+ logits[@bos_token_id] = 0
278
+ end
279
+ logits
280
+ end
281
+ end
282
+
283
+ class ForcedEOSTokenLogitsProcessor < LogitsProcessor
284
+ def initialize(max_length, forced_eos_token_id)
285
+ super()
286
+ @max_length = max_length
287
+ @forced_eos_token_id = forced_eos_token_id
288
+ end
289
+
290
+ def call(input_ids, logits)
291
+ end
292
+ end
293
+ end
294
+ end
@@ -0,0 +1,116 @@
1
+ module Informers
2
+ module Utils
3
+ class RawImage
4
+ RESAMPLING_MAPPING = {
5
+ 0 => "nearest",
6
+ 1 => "lanczos",
7
+ 2 => "bilinear",
8
+ 3 => "bicubic",
9
+ 4 => "box",
10
+ 5 => "hamming",
11
+ }
12
+
13
+ attr_reader :image, :width, :height, :channels
14
+
15
+ def initialize(image)
16
+ @image = image
17
+ @width = image.width
18
+ @height = image.height
19
+ @channels = image.bands
20
+ end
21
+
22
+ def data
23
+ @image.write_to_memory.unpack("C*")
24
+ end
25
+
26
+ def size
27
+ [@width, @height]
28
+ end
29
+
30
+ def resize(width, height, resample: 2)
31
+ resample_method = RESAMPLING_MAPPING[resample] || resample
32
+
33
+ case resample_method
34
+ when "bilinear", "bicubic"
35
+ img =
36
+ @image.affine(
37
+ [width / @width.to_f, 0, 0, height / @height.to_f],
38
+ interpolate: Vips::Interpolate.new(resample_method.to_sym)
39
+ )
40
+ else
41
+ raise Todo
42
+ end
43
+
44
+ RawImage.new(img)
45
+ end
46
+
47
+ def center_crop(crop_width, crop_height)
48
+ # If the image is already the desired size, return it
49
+ if @width == crop_width && @height == crop_height
50
+ return self
51
+ end
52
+
53
+ # Determine bounds of the image in the new canvas
54
+ width_offset = (@width - crop_width) / 2.0
55
+ height_offset = (@height - crop_height) / 2.0
56
+
57
+ if width_offset >= 0 && height_offset >= 0
58
+ # Cropped image lies entirely within the original image
59
+ img = @image.crop(
60
+ width_offset.floor,
61
+ height_offset.floor,
62
+ crop_width,
63
+ crop_height
64
+ )
65
+ elsif width_offset <= 0 && height_offset <= 0
66
+ raise Todo
67
+ else
68
+ raise Todo
69
+ end
70
+
71
+ RawImage.new(img)
72
+ end
73
+
74
+ def rgb
75
+ if @channels == 3
76
+ return self
77
+ end
78
+
79
+ raise Todo
80
+ end
81
+
82
+ def save(path)
83
+ @image.write_to_file(path)
84
+ end
85
+
86
+ def self.read(input)
87
+ if input.is_a?(RawImage)
88
+ input
89
+ elsif input.is_a?(URI)
90
+ require "open-uri"
91
+
92
+ RawImage.new(Vips::Image.new_from_buffer(input.read, ""))
93
+ elsif input.is_a?(String)
94
+ RawImage.new(Vips::Image.new_from_file(input))
95
+ else
96
+ raise ArgumentError, "Unsupported input type: #{input.class.name}"
97
+ end
98
+ end
99
+
100
+ def self.from_array(input)
101
+ c, h, w = Utils.dims(input)
102
+ pixel_data = Array.new(w * h * c)
103
+
104
+ input.each_with_index do |cv, ci|
105
+ cv.each_with_index do |hv, hi|
106
+ hv.each_with_index do |v, wi|
107
+ pixel_data[(hi * w * c) + (wi * c) + ci] = v
108
+ end
109
+ end
110
+ end
111
+
112
+ RawImage.new(Vips::Image.new_from_memory_copy(pixel_data.pack("C*"), w, h, c, :uchar))
113
+ end
114
+ end
115
+ end
116
+ end
@@ -1,5 +1,75 @@
1
1
  module Informers
2
2
  module Utils
3
+ def self.interpolate_data(input, in_shape, out_shape, mode = "bilinear", align_corners = false)
4
+ in_channels, in_height, in_width = in_shape
5
+ out_height, out_width = out_shape
6
+
7
+ # TODO use mode and align_corners
8
+
9
+ # Output image dimensions
10
+ x_scale = out_width / in_width.to_f
11
+ y_scale = out_height / in_height.to_f
12
+
13
+ # Output image
14
+ out_img = Array.new(out_height * out_width * in_channels)
15
+
16
+ # Pre-calculate strides
17
+ in_stride = in_height * in_width;
18
+ out_stride = out_height * out_width;
19
+
20
+ out_height.times do |i|
21
+ out_width.times do |j|
22
+ # Calculate output offset
23
+ out_offset = i * out_width + j
24
+
25
+ # Calculate input pixel coordinates
26
+ x = (j + 0.5) / x_scale - 0.5
27
+ y = (i + 0.5) / y_scale - 0.5
28
+
29
+ # Calculate the four nearest input pixels
30
+ # We also check if the input pixel coordinates are within the image bounds
31
+ x1 = x.floor
32
+ y1 = y.floor
33
+ x2 = [x1 + 1, in_width - 1].min
34
+ y2 = [y1 + 1, in_height - 1].min
35
+
36
+ x1 = [x1, 0].max
37
+ y1 = [y1, 0].max
38
+
39
+ # Calculate the fractional distances between the input pixel and the four nearest pixels
40
+ s = x - x1
41
+ t = y - y1
42
+
43
+ # Perform bilinear interpolation
44
+ w1 = (1 - s) * (1 - t)
45
+ w2 = s * (1 - t)
46
+ w3 = (1 - s) * t
47
+ w4 = s * t
48
+
49
+ # Calculate the four nearest input pixel indices
50
+ y_stride = y1 * in_width
51
+ x_stride = y2 * in_width
52
+ idx1 = y_stride + x1
53
+ idx2 = y_stride + x2
54
+ idx3 = x_stride + x1
55
+ idx4 = x_stride + x2
56
+
57
+ in_channels.times do |k|
58
+ # Calculate channel offset
59
+ c_offset = k * in_stride
60
+
61
+ out_img[k * out_stride + out_offset] =
62
+ w1 * input[c_offset + idx1] +
63
+ w2 * input[c_offset + idx2] +
64
+ w3 * input[c_offset + idx3] +
65
+ w4 * input[c_offset + idx4]
66
+ end
67
+ end
68
+ end
69
+
70
+ out_img
71
+ end
72
+
3
73
  def self.softmax(arr)
4
74
  # Compute the maximum value in the array
5
75
  max_val = arr.max
@@ -17,6 +87,9 @@ module Informers
17
87
  end
18
88
 
19
89
  def self.sigmoid(arr)
90
+ if arr[0].is_a?(Array)
91
+ return arr.map { |a| sigmoid(a) }
92
+ end
20
93
  arr.map { |v| 1 / (1 + Math.exp(-v)) }
21
94
  end
22
95