informers 1.0.3 → 1.1.1
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +137 -7
- data/lib/informers/configs.rb +10 -8
- data/lib/informers/model.rb +2 -9
- data/lib/informers/models.rb +1160 -15
- data/lib/informers/pipelines.rb +943 -11
- data/lib/informers/processors.rb +856 -0
- data/lib/informers/tokenizers.rb +159 -5
- data/lib/informers/utils/audio.rb +18 -0
- data/lib/informers/utils/core.rb +4 -0
- data/lib/informers/utils/ffmpeg.rb +45 -0
- data/lib/informers/utils/generation.rb +294 -0
- data/lib/informers/utils/image.rb +116 -0
- data/lib/informers/utils/math.rb +73 -0
- data/lib/informers/utils/tensor.rb +46 -0
- data/lib/informers/version.rb +1 -1
- data/lib/informers.rb +6 -0
- metadata +10 -5
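
The headline change in this release is generation and multimodal support: models.rb gains a beam-search `generate` method plus seq2seq, vision, and audio model classes, pipelines.rb grows by roughly 940 lines, and new processor, image, audio, and generation utilities are added. A minimal usage sketch of what that surface might look like follows; the "translation" task name and model id are assumptions based on the Transformers.js conventions this gem mirrors, not taken from this diff.

```ruby
require "informers"

# Hypothetical sketch of the expanded pipeline surface in 1.1.x.
# The "translation" task and the model id are assumptions, not confirmed by this diff.
translator = Informers.pipeline("translation", "Xenova/opus-mt-en-de")
puts translator.("The weather is nice today.")

# The embedding pipeline from 1.0.x is unchanged.
embed = Informers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
embeddings = embed.(["informers 1.1 adds generation, vision, and audio pipelines"])
```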
data/lib/informers/models.rb
CHANGED
@@ -44,7 +44,7 @@ module Informers
       end

       const_get(:MODEL_CLASS_MAPPINGS).each do |model_class_mapping|
-        model_info = model_class_mapping[config
+        model_info = model_class_mapping[config[:model_type]]
         if !model_info
           next # Item not found in this mapping
         end
@@ -52,15 +52,17 @@ module Informers
       end

       if const_defined?(:BASE_IF_FAIL)
-        warn "Unknown model class #{config
+        warn "Unknown model class #{config[:model_type].inspect}, attempting to construct from base class."
         PreTrainedModel.from_pretrained(pretrained_model_name_or_path, **options)
       else
-        raise Error, "Unsupported model type: #{config
+        raise Error, "Unsupported model type: #{config[:model_type]}"
       end
     end
   end

   class PreTrainedModel
+    MAIN_INPUT_NAME = :input_ids
+
     attr_reader :config

     def initialize(config, session)
@@ -76,11 +78,24 @@ module Informers

       case model_type
       when MODEL_TYPES[:DecoderOnly]
-
+        @can_generate = true
+
+        @run_beam = method(:decoder_run_beam)
+        @get_start_beams = method(:decoder_start_beams)
+        @update_beam = method(:decoder_update_beam)
+        @forward = method(:decoder_forward)
+
       when MODEL_TYPES[:Seq2Seq], MODEL_TYPES[:Vision2Seq]
-
+        @can_generate = true
+
+        @run_beam = method(:seq2seq_run_beam)
+        @get_start_beams = method(:seq2seq_start_beams)
+        @update_beam = method(:seq2seq_update_beam)
+        @forward = method(:seq2seq_forward)
+
       when MODEL_TYPES[:EncoderDecoder]
-
+        @forward = method(:encoder_forward)
+
       else
         @forward = method(:encoder_forward)
       end
@@ -110,20 +125,37 @@ module Informers
       model_type = MODEL_TYPE_MAPPING[model_name]

       if model_type == MODEL_TYPES[:DecoderOnly]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, options[:model_file_name] || "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]

       elsif model_type == MODEL_TYPES[:Seq2Seq] || model_type == MODEL_TYPES[:Vision2Seq]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+          construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options),
+          Utils::Hub.get_model_json(pretrained_model_name_or_path, "generation_config.json", false, **options)
+        ]

       elsif model_type == MODEL_TYPES[:MaskGeneration]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "vision_encoder", **options),
+          construct_session(pretrained_model_name_or_path, "prompt_encoder_mask_decoder", **options)
+        ]

       elsif model_type == MODEL_TYPES[:EncoderDecoder]
-
+        info = [
+          AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
+          construct_session(pretrained_model_name_or_path, "encoder_model", **options),
+          construct_session(pretrained_model_name_or_path, "decoder_model_merged", **options)
+        ]

       else
         if model_type != MODEL_TYPES[:EncoderOnly]
-          warn "Model type for '#{model_name || config
+          warn "Model type for '#{model_name || config[:model_type]}' not found, assuming encoder-only architecture. Please report this."
         end
         info = [
           AutoConfig.from_pretrained(pretrained_model_name_or_path, **options),
@@ -153,8 +185,445 @@ module Informers
       @forward.(model_inputs, **kwargs)
     end

+    def generate(inputs, generation_config = nil, logits_processor = nil, inputs_attention_mask: nil)
+      if !@can_generate
+        model_name = MODEL_CLASS_TO_NAME_MAPPING[self.class]
+        error_message = "The current model class (#{model_name}) is not compatible with `.generate()`, as it doesn't have a language model head."
+        raise Error, error_message
+      end
+
+      if !inputs.is_a?(Array)
+        raise ArgumentError, "`inputs` must be an Array, but is #{inputs.class.name}"
+      end
+
+      if @config[:is_encoder_decoder]
+        # Generating from the encoder outputs
+        input_ids_seq_length = 0
+      else
+        input_ids_seq_length = inputs.length
+
+        # decoder-only
+        if input_ids_seq_length == 0
+          raise Error, "Must supply a non-empty array of input token ids."
+        end
+      end
+
+      # Update generation config with defaults
+      generation_config = get_generation_config(generation_config)
+
+      logits_processor ||= Utils::LogitsProcessorList.new
+
+      # Update logits processor
+      logits_processor = get_logits_processor(
+        generation_config,
+        input_ids_seq_length,
+        logits_processor
+      )
+
+      eos_token_ids = generation_config[:eos_token_id]
+      if !eos_token_ids.nil? && !eos_token_ids.is_a?(Array)
+        eos_token_ids = [eos_token_ids]
+      end
+
+      num_output_tokens = 1
+      max_output_tokens = num_output_tokens + (generation_config[:max_new_tokens] || Float::INFINITY)
+
+      # Only use max length if max_new_tokens is not provided
+      use_max_length = generation_config[:max_length].is_a?(Integer) && generation_config[:max_new_tokens].nil?
+      sampler = Utils::Sampler.get_sampler(generation_config)
+
+      beams = get_start_beams(inputs, generation_config, num_output_tokens, inputs_attention_mask)
+
+      while beams.any? { |x| !x[:done] } && num_output_tokens < max_output_tokens
+        newest_beams = []
+        beams.each do |beam|
+          if beam[:done]
+            # Add this beam back into the pool
+            newest_beams << beam
+            next
+          end
+          if use_max_length && beam[:output_token_ids].length >= generation_config["max_length"]
+            # Set this beam to done and add it back into the pool
+            beam[:done] = true
+            newest_beams << beam
+            next
+          end
+
+          output = run_beam(beam)
+
+          # add attentions/scores to beam only if user requested
+          if generation_config["output_attentions"]
+            add_attentions_to_beam(beam, output)
+          end
+
+          # Logits are of the form [batch_size, out_seq_length, vocab_size]
+          # In most cases, this will be [batch_size, 1, vocab_size]
+          # So, we select the last token's logits:
+          # (equivalent to `logits = outputs.logits[:, -1, :]`)
+          logits = output["logits"].map { |v| v[-1] }
+
+          # Apply logits processor
+          logits_processor.(beam[:output_token_ids], logits)
+
+          sampled_tokens = sampler.(logits)
+          sampled_tokens.each do |new_token_id, log_prob|
+            # use previous beam as a starting point
+            new_beam = beam.dup
+
+            # update new beam
+            update_beam(new_beam, new_token_id)
+
+            new_beam[:score] += log_prob
+
+            if eos_token_ids && eos_token_ids.include?(new_token_id)
+              new_beam[:done] = true
+            end
+
+            newest_beams << new_beam
+          end
+        end
+        num_output_tokens += 1
+
+        # Next, we get the best beams, per ID
+        newest_beams =
+          group_beams(newest_beams).map do |group|
+            group.sort_by { |v| -v[:score] }[0...generation_config["num_beams"]]
+          end
+
+        # Flatten beams
+        beams = newest_beams.flatten(1)
+
+        # Run callback
+        if generation_config["callback_function"]
+          generation_config["callback_function"].(beams)
+        end
+      end
+
+      # TODO: Ensure that we can return non-batched outputs
+
+      grouped_beams = group_beams(beams)
+
+      get_flattened = lambda do |key|
+        grouped_beams.flat_map do |batch|
+          if generation_config["num_return_sequences"] > 1
+            raise Todo
+          else
+            [batch[0][key]]
+          end
+        end
+      end
+
+      sequences = get_flattened.(:output_token_ids) # [1, seqLength]
+
+      if generation_config["return_dict_in_generate"]
+        raise Todo
+      else
+        sequences
+      end
+    end
+
     private

+    def get_logits_processor(
+      generation_config,
+      input_ids_seq_length,
+      logits_processor = nil
+    )
+      processors = Utils::LogitsProcessorList.new
+
+      if !generation_config["repetition_penalty"].nil? && generation_config["repetition_penalty"] != 1.0
+        processors.push(Utils::RepetitionPenaltyLogitsProcessor.new(generation_config["repetition_penalty"]))
+      end
+
+      if !generation_config["no_repeat_ngram_size"].nil? && generation_config["no_repeat_ngram_size"] > 0
+        processors.push(Utils::NoRepeatNGramLogitsProcessor.new(generation_config["no_repeat_ngram_size"]))
+      end
+
+      if !generation_config["bad_words_ids"].nil?
+        processors.push(Utils::NoBadWordsLogitsProcessor.new(generation_config["bad_words_ids"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_length"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_length"] > 0
+        processors.push(Utils::MinLengthLogitsProcessor.new(generation_config["min_length"], generation_config["eos_token_id"]))
+      end
+
+      if !generation_config["min_new_tokens"].nil? && !generation_config["eos_token_id"].nil? && generation_config["min_new_tokens"] > 0
+        processors.push(Utils::MinNewTokensLengthLogitsProcessor.new(
+          input_ids_seq_length,
+          generation_config["min_new_tokens"],
+          generation_config["eos_token_id"]
+        ))
+      end
+
+      if !generation_config["forced_bos_token_id"].nil?
+        processors.push(Utils::ForcedBOSTokenLogitsProcessor.new(generation_config["forced_bos_token_id"]))
+      end
+
+      if !generation_config["forced_eos_token_id"].nil?
+        processors.push(Utils::ForcedEOSTokenLogitsProcessor.new(
+          generation_config["max_length"],
+          generation_config["forced_eos_token_id"]
+        ))
+      end
+
+      if !generation_config["begin_suppress_tokens"].nil?
+        raise Todo
+      end
+
+      if !generation_config["forced_decoder_ids"].nil?
+        processors.push(Utils::ForceTokensLogitsProcessor.new(generation_config["forced_decoder_ids"]))
+      end
+
+      if !logits_processor.nil?
+        processors.concat(logits_processor)
+      end
+
+      processors
+    end
+
+    def get_generation_config(generation_config)
+      # Create empty generation config (contains defaults)
+      # We pass `@config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
+      gen_config = Utils::GenerationConfig.new(@config.to_h)
+
+      # Apply model's generation config, if it exists
+      if @generation_config
+        gen_config.merge!(@generation_config)
+      end
+
+      # Finally, use any generation config specified by the user
+      # when calling `generate`
+      if !generation_config.nil?
+        gen_config.merge!(generation_config)
+      end
+
+      gen_config
+    end
+
+    def seq2seq_forward(model_inputs)
+      encoder_outputs = model_inputs[:encoder_outputs]
+      past_key_values = model_inputs[:past_key_values]
+
+      if !encoder_outputs
+        # Encoder outputs are not given, so we must compute them.
+        encoder_outputs = encoder_forward(model_inputs)[0]
+      end
+      decoder_feeds = {
+        input_ids: model_inputs[:decoder_input_ids],
+        encoder_hidden_states: encoder_outputs
+      }
+      use_cache_branch = !!past_key_values
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      if @decoder_merged_session.inputs.map { |v| v[:name] }.include?("encoder_attention_mask")
+        decoder_feeds[:encoder_attention_mask] = model_inputs[:attention_mask]
+      end
+
+      prepare_position_ids(@decoder_merged_session, decoder_feeds, use_cache_branch)
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@decoder_merged_session, decoder_feeds)
+      decoder_results = @decoder_merged_session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+      logits = decoder_results["logits"]
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+
+      # Get cross attention and/or decoder attentions if they are present
+      attns = get_attentions(decoder_results)
+
+      Seq2SeqLMOutput.new(logits, past_key_values, encoder_outputs, attns["decoder_attentions"], attns["cross_attentions"])
+    end
+
+    def prepare_position_ids(session, feeds, use_cache_branch)
+      if !session.inputs.map { |v| v[:name] }.include?("position_ids")
+        return
+      end
+
+      raise Todo
+    end
+
+    def get_past_key_values(decoder_results, past_key_values)
+      pkvs = {}
+
+      decoder_results.each_key do |name|
+        if name.start_with?("present")
+          new_name = name.sub("present", "past_key_values")
+
+          if past_key_values && name.include?("encoder")
+            # Optimization introduced by optimum to reuse past key values. So, we just replace the constant
+            # outputs with the previous past key values.
+            # https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
+            pkvs[new_name] = past_key_values[new_name]
+          else
+            pkvs[new_name] = decoder_results[name]
+          end
+        end
+      end
+      pkvs
+    end
+
+    def get_attentions(decoder_results)
+      attns = {}
+
+      ["cross_attentions", "decoder_attentions"].each do |attn_name|
+        result = []
+        decoder_results.each_key do |name|
+          if name.start_with?(attn_name)
+            index = name.split(".").pop
+            result[index] = decoder_results[name]
+          end
+        end
+        attns[attn_name] = result
+      end
+      attns
+    end
+
+    def add_past_key_values(decoder_feeds, past_key_values)
+      if past_key_values
+        decoder_feeds.merge!(past_key_values)
+      else
+        # TODO support batches (i.e., batch_size > 1)
+        batch_size = 1
+
+        if @config[:is_encoder_decoder] && (!@add_encoder_pkv.nil? ? @add_encoder_pkv : true)
+          _encoder_dims = [batch_size, @num_encoder_heads, 0, @encoder_dim_kv]
+          _decoder_dims = [batch_size, @num_decoder_heads, 0, @decoder_dim_kv]
+          @num_decoder_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.encoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.encoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(encoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.key"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+            # decoder_feeds["past_key_values.#{i}.decoder.value"] = OnnxRuntime::OrtValue.from_shape_and_type(decoder_dims, :float)
+          end
+        elsif @config[:model_type] == "falcon"
+          raise Todo
+        elsif @config[:multi_query]
+          raise Todo
+        elsif @config[:model_type] == "bloom"
+          raise Todo
+        else
+          _dims = [batch_size, @num_heads, 0, @dim_kv]
+          @num_layers.times do |i|
+            # decoder_feeds["past_key_values.#{i}.key"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+            # decoder_feeds["past_key_values.#{i}.value"] = OnnxRuntime::OrtValue.from_shape_and_type(dims, :float)
+          end
+        end
+      end
+    end
+
+    def seq2seq_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask = nil)
+      beams = []
+      beam_id = 0
+
+      requires_attention_mask = !@requires_attention_mask.nil? ? @requires_attention_mask : true
+
+      # decoder_input_ids == output_token_ids
+      decoder_input_ids =
+        generation_config["decoder_input_ids"] ||
+        generation_config["decoder_start_token_id"] ||
+        generation_config["bos_token_id"] ||
+        generation_config["eos_token_id"]
+
+      if !decoder_input_ids.is_a?(Array)
+        decoder_input_ids = [decoder_input_ids]
+      end
+
+      input_token_ids.each do |tokens|
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        # Create beam
+        start = {
+          inputs: tokens,
+          encoder_outputs: nil,
+          prev_model_outputs: nil,
+
+          output_token_ids: decoder_input_ids,
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        if requires_attention_mask
+          start[:attention_mask] = prepare_attention_mask(tokens)
+        end
+
+        beams << start
+      end
+
+      beams
+    end
+
+    def prepare_attention_mask(tokens)
+      # Prepare attention mask
+      pad_token_id = @config["pad_token_id"]
+      eos_token_id = @config["eos_token_id"]
+      if eos_token_id.is_a?(Integer)
+        eos_token_id = [eos_token_id]
+      end
+
+      is_pad_token_in_inputs = !tokens.index(pad_token_id).nil?
+      is_pad_token_not_equal_to_eos_token_id = eos_token_id.nil? || !eos_token_id.include?(pad_token_id)
+
+      if is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id
+        raise Todo
+      else
+        Utils.ones_like(tokens)
+      end
+    end
+
+    def seq2seq_run_beam(beam)
+      input_name = self.class.const_get(:MAIN_INPUT_NAME)
+
+      decoder_input_ids = beam[:output_token_ids]
+      if beam[:prev_model_outputs]
+        # After the first step, `prev_model_outputs` won't be null.
+        # So, we cut decoder_input_ids if past is used
+        decoder_input_ids = [decoder_input_ids[-1]]
+      end
+
+      # 1. Prepare
+      model_inputs = {
+        input_name => beam[:inputs],
+        decoder_input_ids: [decoder_input_ids],
+        encoder_outputs: beam[:encoder_outputs],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+      if beam[:attention_mask]
+        model_inputs[:attention_mask] = beam[:attention_mask]
+      end
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+      beam[:encoder_outputs] = output[:encoder_outputs]
+
+      output
+    end
+
+    def seq2seq_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+    end
+
+    def group_beams(beams)
+      # Group beams by their ids
+      groups = {}
+      beams.each do |obj|
+        if !groups[obj[:id]]
+          groups[obj[:id]] = [obj]
+        else
+          groups[obj[:id]] << obj
+        end
+      end
+      groups.values
+    end
+
     def encoder_forward(model_inputs, output_names: nil)
       encoder_feeds = {}
       @session.inputs.each do |input|
@@ -167,7 +636,96 @@ module Informers
       session_run(@session, encoder_feeds, output_names:)
     end

-    def
+    def decoder_forward(model_inputs)
+      input_ids, past_key_values, attention_mask =
+        model_inputs.values_at(:input_ids, :past_key_values, :attention_mask)
+      decoder_feeds = {
+        input_ids: input_ids,
+        attention_mask: attention_mask || prepare_attention_mask(input_ids)
+      }
+      use_cache_branch = !!past_key_values
+
+      if @session.inputs.map { |v| v[:name] }.include?("use_cache_branch")
+        decoder_feeds[:use_cache_branch] = [use_cache_branch]
+      end
+
+      prepare_position_ids(@session, decoder_feeds, use_cache_branch)
+
+      add_past_key_values(decoder_feeds, past_key_values)
+
+      decoder_results = session_run(@session, decoder_feeds)
+      decoder_results = @session.outputs.map { |v| v[:name] }.zip(decoder_results).to_h
+
+      logits = decoder_results["logits"]
+
+      past_key_values = get_past_key_values(decoder_results, past_key_values)
+      {"logits" => logits, past_key_values: past_key_values}
+    end
+
+    def decoder_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      beams = []
+
+      beam_id = 0
+      input_token_ids.each do |tokens|
+        output_token_ids = tokens.dup
+
+        # TODO: Improve
+        # Currently, just add back batch dimension.
+        # In future, allow for true parallel execution
+        tokens = [tokens]
+
+        if inputs_attention_mask
+          attn_mask = inputs_attention_mask[beam_id]
+          attn_mask = [attn_mask]
+        else
+          attn_mask = prepare_attention_mask(tokens)
+        end
+
+        start = {
+          input: tokens,
+          model_input_ids: tokens,
+          attention_mask: attn_mask,
+          prev_model_outputs: nil,
+
+          output_token_ids: output_token_ids,
+          num_output_tokens: num_output_tokens,
+
+          done: false,
+          score: 0,
+          id: beam_id # assign unique id to beams
+        }
+        beam_id += 1
+
+        beams << start
+      end
+      beams
+    end
+
+    def decoder_run_beam(beam)
+      attn_mask_data = Array.new(beam[:output_token_ids].length, 1)
+
+      # 1. Prepare
+      model_inputs = {
+        input_ids: beam[:model_input_ids],
+        attention_mask: [attn_mask_data],
+        past_key_values: beam[:prev_model_outputs] && beam[:prev_model_outputs][:past_key_values]
+      }
+
+      # 2. Run
+      output = @forward.(model_inputs)
+
+      # 3. Update
+      beam[:prev_model_outputs] = output
+
+      output
+    end
+
+    def decoder_update_beam(beam, new_token_id)
+      beam[:output_token_ids] += [new_token_id]
+      beam[:model_input_ids] = [[new_token_id]]
+    end
+
+    def session_run(session, inputs, output_names: nil)
       checked_inputs = validate_inputs(session, inputs)
       begin
         output = session.run(output_names || @output_names, checked_inputs)
@@ -187,6 +745,18 @@ module Informers
     def validate_inputs(session, inputs)
       inputs
     end
+
+    def get_start_beams(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+      @get_start_beams.(input_token_ids, generation_config, num_output_tokens, inputs_attention_mask)
+    end
+
+    def run_beam(beam)
+      @run_beam.(beam)
+    end
+
+    def update_beam(beam, new_token_id)
+      @update_beam.(beam, new_token_id)
+    end
   end

   class BertPreTrainedModel < PreTrainedModel
@@ -195,6 +765,12 @@ module Informers
   class BertModel < BertPreTrainedModel
   end

+  class BertForMaskedLM < BertPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class BertForSequenceClassification < BertPreTrainedModel
     def call(model_inputs)
       SequenceClassifierOutput.new(*super(model_inputs))
@@ -243,6 +819,126 @@ module Informers
   class MPNetModel < MPNetPreTrainedModel
   end

+  class T5PreTrainedModel < PreTrainedModel
+  end
+
+  class T5Model < T5PreTrainedModel
+  end
+
+  class T5ForConditionalGeneration < T5PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config[:num_decoder_layers]
+      @num_decoder_heads = @config[:num_heads]
+      @decoder_dim_kv = @config[:d_kv]
+
+      @num_encoder_layers = @config[:num_layers]
+      @num_encoder_heads = @config[:num_heads]
+      @encoder_dim_kv = @config[:d_kv]
+    end
+  end
+
+  class BartPretrainedModel < PreTrainedModel
+  end
+
+  class BartModel < BartPretrainedModel
+  end
+
+  class BartForConditionalGeneration < BartPretrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads
+    end
+  end
+
+  class BartForSequenceClassification < BartPretrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class MBartPreTrainedModel < PreTrainedModel
+  end
+
+  class MBartModel < MBartPreTrainedModel
+  end
+
+  class MBartForCausalLM < MBartPreTrainedModel
+    attr_reader :num_decoder_layers, :num_decoder_heads, :decoder_dim_kv,
+      :num_encoder_layers, :num_encoder_heads, :encoder_dim_kv
+
+    def initialize(config, decoder_merged_session, generation_config)
+      super(config, decoder_merged_session)
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class M2M100PreTrainedModel < PreTrainedModel
+  end
+
+  class M2M100Model < M2M100PreTrainedModel
+  end
+
+  class M2M100ForConditionalGeneration < M2M100PreTrainedModel
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+  end
+
+  class Wav2Vec2PreTrainedModel < PreTrainedModel
+  end
+
+  class Wav2Vec2Model < Wav2Vec2PreTrainedModel
+  end
+
+  class Wav2Vec2ForSequenceClassification < Wav2Vec2PreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class RobertaPreTrainedModel < PreTrainedModel
+  end
+
+  class RobertaModel < RobertaPreTrainedModel
+  end
+
+  class RobertaForMaskedLM < RobertaPreTrainedModel
+    def call(model_inputs)
+      MaskedLMOutput.new(*super(model_inputs))
+    end
+  end
+
   class XLMRobertaPreTrainedModel < PreTrainedModel
   end

@@ -255,34 +951,351 @@ module Informers
     end
   end

+  class ViTPreTrainedModel < PreTrainedModel
+  end
+
+  class ViTModel < ViTPreTrainedModel
+  end
+
+  class ViTForImageClassification < ViTPreTrainedModel
+    def call(model_inputs)
+      SequenceClassifierOutput.new(*super(model_inputs))
+    end
+  end
+
+  class CLIPPreTrainedModel < PreTrainedModel
+  end
+
+  class CLIPModel < CLIPPreTrainedModel
+  end
+
+  class GPT2PreTrainedModel < PreTrainedModel
+    attr_reader :num_heads, :num_layers, :dim_kv
+
+    def initialize(config, session, generation_config)
+      super(config, session)
+      @generation_config = generation_config
+
+      # config doesn't contain pad_token_id, so we assume it is the eos_token_id
+      @config["pad_token_id"] = @config["eos_token_id"]
+
+      @num_heads = @config["n_head"]
+      @num_layers = @config["n_layer"]
+      @dim_kv = @config["n_embd"] / @num_heads.to_f
+    end
+  end
+
+  class GPT2Model < GPT2PreTrainedModel
+  end
+
+  class GPT2LMHeadModel < GPT2PreTrainedModel
+  end
+
+  class OwlViTPreTrainedModel < PreTrainedModel
+  end
+
+  class OwlViTModel < OwlViTPreTrainedModel
+  end
+
+  class OwlViTForObjectDetection < OwlViTPreTrainedModel
+  end
+
+  class DetrPreTrainedModel < PreTrainedModel
+  end
+
+  class DetrModel < DetrPreTrainedModel
+  end
+
+  class DetrForObjectDetection < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrObjectDetectionOutput.new(*super(model_inputs))
+    end
+  end
+
+  class DetrForSegmentation < DetrPreTrainedModel
+    def call(model_inputs)
+      DetrSegmentationOutput.new(*super(model_inputs))
+    end
+  end
+
+  class Swin2SRPreTrainedModel < PreTrainedModel
+  end
+
+  class Swin2SRModel < Swin2SRPreTrainedModel
+  end
+
+  class Swin2SRForImageSuperResolution < Swin2SRPreTrainedModel
+  end
+
+  class DPTPreTrainedModel < PreTrainedModel
+  end
+
+  class DPTModel < DPTPreTrainedModel
+  end
+
+  class DPTForDepthEstimation < DPTPreTrainedModel
+  end
+
+  class VisionEncoderDecoderModel < PreTrainedModel
+    MAIN_INPUT_NAME = :pixel_values
+
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      # Extract configs
+      encoder_config = @config["encoder"]
+      decoder_config = @config["decoder"]
+
+      # Validate encoder
+      encoder_model_type = encoder_config["model_type"]
+      encoder_model = MODEL_MAPPING_NAMES_ENCODER_ONLY[encoder_model_type] || MODEL_MAPPING_NAMES_ENCODER_DECODER[encoder_model_type]
+      if !encoder_model
+        warn "Model type for encoder '#{encoder_model_type}' not found, assuming encoder-only architecture. Please report this."
+      end
+
+      # Validate decoder
+      decoder_model = MODEL_WITH_LM_HEAD_MAPPING_NAMES[decoder_config["model_type"]]
+      if !decoder_model
+        raise Error, "Unable to construct `VisionEncoderDecoder` due to unsupported decoder: \"#{decoder_config["model_type"]}\""
+      end
+
+      decoder_model_class = decoder_model[1]
+      decoder = decoder_model_class.new(decoder_config, decoder_merged_session, generation_config)
+
+      @add_encoder_pkv = decoder.respond_to?(:num_decoder_layers)
+      if @add_encoder_pkv
+        # Decoder is part of an encoder-decoder model
+        @num_decoder_layers = decoder.num_decoder_layers
+        @num_decoder_heads = decoder.num_decoder_heads
+        @decoder_dim_kv = decoder.decoder_dim_kv
+
+        @num_encoder_layers = decoder.num_encoder_layers
+        @num_encoder_heads = decoder.num_encoder_heads
+        @encoder_dim_kv = decoder.encoder_dim_kv
+      else
+        # Decoder is a decoder-only model
+        @num_layers = decoder.num_layers
+        @num_heads = decoder.num_heads
+        @dim_kv = decoder.dim_kv
+      end
+    end
+  end
+
+  class DonutSwinPreTrainedModel < PreTrainedModel
+  end
+
+  class DonutSwinModel < DonutSwinPreTrainedModel
+  end
+
+  class WhisperPreTrainedModel < PreTrainedModel
+  end
+
+  class WhisperModel < WhisperPreTrainedModel
+  end
+
+  class WhisperForConditionalGeneration < WhisperPreTrainedModel
+    REQUIRES_ATTENTION_MASK = false
+    MAIN_INPUT_NAME = :input_features
+
+    def initialize(config, session, decoder_merged_session, generation_config)
+      super(config, session)
+      @decoder_merged_session = decoder_merged_session
+      @generation_config = generation_config
+
+      @num_decoder_layers = @config["decoder_layers"]
+      @num_decoder_heads = @config["decoder_attention_heads"]
+      @decoder_dim_kv = @config["d_model"] / @num_decoder_heads.to_f
+
+      @num_encoder_layers = @config["encoder_layers"]
+      @num_encoder_heads = @config["encoder_attention_heads"]
+      @encoder_dim_kv = @config["d_model"] / @num_encoder_heads.to_f
+    end
+
+    def generate(inputs, generation_config = nil, logits_processor = nil)
+      raise Todo
+    end
+  end
+
+  class VitsPreTrainedModel < PreTrainedModel
+  end
+
+  class VitsModel < VitsPreTrainedModel
+    def call(model_inputs)
+      VitsModelOutput.new(*super(model_inputs))
+    end
+  end
+
+  class SpeechT5PreTrainedModel < PreTrainedModel
+  end
+
+  class SpeechT5Model < SpeechT5PreTrainedModel
+  end
+
+  class SpeechT5ForSpeechToText < SpeechT5PreTrainedModel
+  end
+
+  class SpeechT5ForTextToSpeech < SpeechT5PreTrainedModel
+  end
+
+  class ClapPreTrainedModel < PreTrainedModel
+  end
+
+  class ClapModel < ClapPreTrainedModel
+  end
+
   MODEL_MAPPING_NAMES_ENCODER_ONLY = {
     "bert" => ["BertModel", BertModel],
     "nomic_bert" => ["NomicBertModel", NomicBertModel],
     "deberta-v2" => ["DebertaV2Model", DebertaV2Model],
     "mpnet" => ["MPNetModel", MPNetModel],
     "distilbert" => ["DistilBertModel", DistilBertModel],
-    "
+    "roberta" => ["RobertaModel", RobertaModel],
+    "xlm-roberta" => ["XLMRobertaModel", XLMRobertaModel],
+    "clap" => ["ClapModel", ClapModel],
+    "clip" => ["CLIPModel", CLIPModel],
+    "detr" => ["DetrModel", DetrModel],
+    "vit" => ["ViTModel", ViTModel],
+    "owlvit" => ["OwlViTModel", OwlViTModel],
+    "donut-swin" => ["DonutSwinModel", DonutSwinModel]
+  }
+
+  MODEL_MAPPING_NAMES_ENCODER_DECODER = {
+    "bart" => ["BartModel", BartModel]
+  }
+
+  MODEL_MAPPING_NAMES_DECODER_ONLY = {
+  }
+
+  MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = {
+    "whisper" => ["WhisperForConditionalGeneration", WhisperForConditionalGeneration]
+  }
+
+  MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = {
+    "speecht5" => ["SpeechT5ForTextToSpeech", SpeechT5ForTextToSpeech]
+  }
+
+  MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = {
+    "vits" => ["VitsModel", VitsModel]
   }

   MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForSequenceClassification", BertForSequenceClassification],
     "distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
-    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification]
+    "xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
+    "bart" => ["BartForSequenceClassification", BartForSequenceClassification]
   }

   MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
     "bert" => ["BertForTokenClassification", BertForTokenClassification]
   }

+  MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
+    "t5" => ["T5ForConditionalGeneration", T5ForConditionalGeneration],
+    "bart" => ["BartForConditionalGeneration", BartForConditionalGeneration],
+    "m2m_100" => ["M2M100ForConditionalGeneration", M2M100ForConditionalGeneration]
+  }
+
+  MODEL_WITH_LM_HEAD_MAPPING_NAMES = {
+    "gpt2" => ["GPT2LMHeadModel", GPT2LMHeadModel],
+    "mbart" => ["MBartForCausalLM", MBartForCausalLM]
+  }
+
+  MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
+    "bert" => ["BertForMaskedLM", BertForMaskedLM],
+    "roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
+  }
+
   MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = {
     "distilbert" => ["DistilBertForQuestionAnswering", DistilBertForQuestionAnswering]
   }

+  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = {
+    "vision-encoder-decoder" => ["VisionEncoderDecoderModel", VisionEncoderDecoderModel]
+  }
+
+  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = {
+    "vit" => ["ViTForImageClassification", ViTForImageClassification]
+  }
+
+  MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = {
+    "detr" => ["DetrForObjectDetection", DetrForObjectDetection]
+  }
+
+  MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = {
+    "owlvit" => ["OwlViTForObjectDetection", OwlViTForObjectDetection]
+  }
+
+  MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = {
+    "detr" => ["DetrForSegmentation", DetrForSegmentation]
+  }
+
+  MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_CTC_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = {
+    "wav2vec2" => ["Wav2Vec2ForSequenceClassification", Wav2Vec2ForSequenceClassification]
+  }
+
+  MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = {
+  }
+
+  MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = {
+    "swin2sr" => ["Swin2SRForImageSuperResolution", Swin2SRForImageSuperResolution]
+  }
+
+  MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = {
+    "dpt" => ["DPTForDepthEstimation", DPTForDepthEstimation]
+  }
+
+  MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = {
+  }
+
   MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES[:EncoderDecoder]],
+    [MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES[:DecoderOnly]],
     [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
     [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
-    [
+    [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_WITH_LM_HEAD_MAPPING_NAMES, MODEL_TYPES[:DecoderOnly]],
+    [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES[:Vision2Seq]],
+    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES[:MaskGeneration]],
+    [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES[:Seq2Seq]],
+    [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]],
+    [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES, MODEL_TYPES[:EncoderOnly]]
   ]

   MODEL_CLASS_TYPE_MAPPING.each do |mappings, type|
@@ -306,11 +1319,113 @@ module Informers
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]
   end

+  class AutoModelForSeq2SeqLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES]
+  end
+
+  class AutoModelForSpeechSeq2Seq < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES]
+  end
+
+  class AutoModelForTextToSpectrogram < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES]
+  end
+
+  class AutoModelForTextToWaveform < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES]
+  end
+
+  class AutoModelForCausalLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]
+  end
+
+  class AutoModelForMaskedLM < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]
+  end
+
   class AutoModelForQuestionAnswering < PretrainedMixin
     MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]
   end

+  class AutoModelForVision2Seq < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForSemanticSegmentation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForZeroShotObjectDetection < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES]
+  end
+
+  class AutoModelForMaskGeneration < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForCTC < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_CTC_MAPPING_NAMES]
+  end
+
+  class AutoModelForAudioClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForXVector < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES]
+  end
+
+  class AutoModelForAudioFrameClassification < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForDocumentQuestionAnswering < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageMatting < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageToImage < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES]
+  end
+
+  class AutoModelForDepthEstimation < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES]
+  end
+
+  class AutoModelForImageFeatureExtraction < PretrainedMixin
+    MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES]
+  end
+
   class ModelOutput
+    def [](key)
+      instance_variable_get("@#{key}")
+    end
+  end
+
+  class Seq2SeqLMOutput < ModelOutput
+    def initialize(logits, past_key_values, encoder_outputs, decoder_attentions = nil, cross_attentions = nil)
+      super()
+      @logits = logits
+      @past_key_values = past_key_values
+      @encoder_outputs = encoder_outputs
+      @decoder_attentions = decoder_attentions
+      @cross_attentions = cross_attentions
+    end
   end

   class SequenceClassifierOutput < ModelOutput
@@ -331,6 +1446,15 @@ module Informers
     end
   end

+  class MaskedLMOutput < ModelOutput
+    attr_reader :logits
+
+    def initialize(logits)
+      super()
+      @logits = logits
+    end
+  end
+
   class QuestionAnsweringModelOutput < ModelOutput
     attr_reader :start_logits, :end_logits

@@ -340,4 +1464,25 @@ module Informers
       @end_logits = end_logits
     end
   end
+
+  class DetrObjectDetectionOutput < ModelOutput
+    attr_reader :logits, :pred_boxes
+
+    def initialize(logits, pred_boxes)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+    end
+  end
+
+  class DetrSegmentationOutput < ModelOutput
+    attr_reader :logits, :pred_boxes, :pred_masks
+
+    def initialize(logits, pred_boxes, pred_masks)
+      super()
+      @logits = logits
+      @pred_boxes = pred_boxes
+      @pred_masks = pred_masks
+    end
+  end
 end
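
As a footnote to the `generate` method added in this diff: it tracks candidate hypotheses as plain hashes ("beams") carrying `:id`, `:output_token_ids`, `:score`, and `:done` fields, expands each beam with sampled tokens, and then keeps the best candidates per input id. The condensed, self-contained sketch below mirrors that bookkeeping with the model forward pass and sampler stubbed out; it is an illustration of the loop's shape, not code from the gem.

```ruby
# Condensed illustration of the beam bookkeeping used by PreTrainedModel#generate.
# The forward pass and sampler are stubbed with fixed (token, log_prob) pairs.
beams = [{id: 0, output_token_ids: [0], score: 0.0, done: false}]
eos_token_id = 1

3.times do
  beams = beams.flat_map do |beam|
    next [beam] if beam[:done] # finished beams go back into the pool unchanged
    # run_beam + the sampler would propose candidate tokens here
    [[5, -0.1], [7, -2.3]].map do |token, log_prob|
      new_beam = beam.dup
      new_beam[:output_token_ids] += [token] # update_beam
      new_beam[:score] += log_prob
      new_beam[:done] = true if token == eos_token_id
      new_beam
    end
  end
  # keep the best hypotheses per input id, as group_beams + sort_by do in the diff
  beams = beams.group_by { |b| b[:id] }.values.flat_map { |g| g.max_by(2) { |b| b[:score] } }
end

pp beams.max_by { |b| b[:score] }[:output_token_ids]
```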