optimum-rbln 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +6 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
- optimum/rbln/diffusers/models/controlnet.py +93 -23
- optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
- optimum/rbln/diffusers/pipelines/__init__.py +7 -2
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
- optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
- optimum/rbln/modeling_base.py +36 -3
- optimum/rbln/modeling_seq2seq.py +19 -4
- optimum/rbln/transformers/generation/__init__.py +1 -0
- optimum/rbln/transformers/generation/streamers.py +17 -0
- optimum/rbln/transformers/generation/utils.py +399 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
- optimum/rbln/transformers/models/llama/modeling_llama.py +63 -45
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/RECORD +29 -25
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,7 +23,6 @@
 
 import inspect
 import logging
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -32,19 +31,13 @@ import rebel
 import torch
 from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
-from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import (
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from transformers.generation.streamers import BaseStreamer
-from transformers.generation.utils import SampleDecoderOnlyOutput
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
 
 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.save_utils import maybe_save_preprocessors
+from ...generation.utils import RBLNGenerationMixin
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
@@ -66,7 +59,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=logits)
 
 
-class RBLNGPT2LMHeadModel(RBLNBaseModel):
+class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
@@ -135,6 +128,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNGPT2LMHeadModel":
         """
@@ -144,8 +138,16 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-
-
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         kwargs.update(
             {
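The hunk above threads a new `model_save_dir` argument through the export classmethod: when it is omitted the compiled artifacts land in a fresh `TemporaryDirectory`, otherwise the given directory is reused (and created if needed). A minimal sketch of that resolution logic, assuming a hypothetical helper name (`resolve_save_dir` is not part of optimum-rbln):

```python
# Hypothetical, minimal sketch of the save-directory resolution shown above.
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union


def resolve_save_dir(
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
) -> Tuple[Union[TemporaryDirectory, str, Path], Path]:
    if model_save_dir is None:
        # No directory given: compiled artifacts go into a fresh temp dir.
        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)
    else:
        save_dir = model_save_dir
        if isinstance(save_dir, TemporaryDirectory):
            save_dir_path = Path(save_dir.name)
        else:
            save_dir_path = Path(save_dir)
            save_dir_path.mkdir(exist_ok=True)
    return save_dir, save_dir_path
```

The same pattern is repeated, with the same branching, in the Llama and Whisper classes further down.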
@@ -264,8 +266,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_pad_token_id"] = rbln_pad_token_id
 
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         def get_input_info(query_length):
             return [
@@ -320,6 +321,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
             self.prompt_ids = input_ids
             self.rightpad_max_len = cur_len
             prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+            self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
 
             if cur_len % self.prefill_chunk_size == 0:
                 pad_len = 0
@@ -329,12 +331,12 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
             attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
 
-            query_length = prompt_min_len
+            query_length = prompt_min_len.item()
         else:
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
             attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
             attention_mask[:, :, :, : cache_position + 1] = 1
-            input_ids = input_ids[:,
+            input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
             query_length = 1
 
         model_inputs = {
@@ -357,25 +359,23 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         query_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-
         if past_key_values is not None:
             past_key_values += query_length
 
         if cache_position == 0:
-            for
-                sliced_input_ids = input_ids[:,
-                attention_mask[:, :, :, :
-                attention_mask[:, :, :,
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
                 output = self.prefill_decoder(
                     input_ids=sliced_input_ids.contiguous(),
                     attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position,
+                    cache_position=cache_position + step,
                 )
-                query_length -= self.prefill_chunk_size
-                cache_position += self.prefill_chunk_size
 
-
+            idx = query_length % self.prefill_chunk_size - 1
+            output = output.logits[:, idx].unsqueeze(1)
 
         else:
             output = self.decoder(
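The prefill pass above now walks the prompt in fixed `prefill_chunk_size` windows: already-processed positions are made fully visible, the current window gets the causal mask, and only the logits of the last real prompt token are kept. A self-contained sketch of that bookkeeping with plain tensors; the shapes, the vocabulary size and `fake_decoder` are assumptions for the example (the real model calls a compiled RBLN runtime):

```python
# Illustrative sketch of the chunked-prefill masking pattern shown above.
import torch

max_seq_len, prefill_chunk_size = 16, 4
query_length = 10                      # number of real prompt tokens
causal_mask = torch.tril(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size, dtype=torch.int64))

input_ids = torch.randint(0, 100, (1, 12))           # prompt padded to a chunk multiple
attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.int64)


def fake_decoder(input_ids, attention_mask, cache_position):
    # stand-in for the prefill runtime: dummy logits of shape [batch, chunk, vocab]
    return torch.randn(input_ids.shape[0], input_ids.shape[1], 100)


for step in range(0, query_length, prefill_chunk_size):
    sliced_input_ids = input_ids[:, step : step + prefill_chunk_size]
    attention_mask[:, :, :, :step] = 1                                       # finished chunks: fully visible
    attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask  # causal inside current chunk
    output = fake_decoder(sliced_input_ids, attention_mask, cache_position=step)

# keep only the logits of the last real prompt token in the final chunk
idx = query_length % prefill_chunk_size - 1
next_token_logits = output[:, idx].unsqueeze(1)
```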
@@ -389,312 +389,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
 
     def __repr__(self):
         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
-
-    # call 'greedy_search` directly is deprecated and removed in v4.41.
-    def greedy_search(self, *args, **kwargs):
-        return self._greedy_search(*args, **kwargs)
-
-    def _greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_logits: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )
-
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-
-        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-
-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-            )
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            ########################################################################################################
-            # thkim change for right-padding batch
-            # if min_input_len <= update_idx < max_input_len
-            # update validate input_ids[:,update_idx]
-            # TODO : raw_logits contains dummy next_token's logits
-            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-            if update_idx < self.rightpad_max_len:
-                # update exist input_ids rather than concat
-                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
-                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
-                model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
-                # dummy next_token -> pad_token_id for streamer
-                # in order to skip by 'skip_special_tokens = True"
-                dummy_indices = ~valid_indices
-                next_tokens[dummy_indices] = pad_token_id
-            else:
-                ############################################END#########################################################
-                # update generated ids, model inputs, and length for next step
-                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                ####################################################################
-                # thkim : to do not finish sequence of dummy_decoder of right_padding
-                if hasattr(self, "rightpad_max_len"):
-                    update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-                    if update_idx < self.rightpad_max_len:
-                        next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
-                ######################################################################
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            # thkim : backward compatibility bool vs torch.BoolTensor
-            is_stop = stopping_criteria(input_ids, None)
-            if isinstance(is_stop, torch.BoolTensor):
-                is_stop = torch.all(is_stop)
-            if is_stop:
-                this_peer_finished = True
-
-            if this_peer_finished:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return SampleDecoderOnlyOutput(
-                sequences=input_ids,
-                logits=raw_logits,
-            )
-        else:
-            return input_ids
-
-    # call 'sample` directly is deprecated and removed in v4.41.
-    def sample(self, *args, **kwargs):
-        return self._sample(*args, **kwargs)
-
-    def _sample(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        output_logits: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_logits = output_logits if output_logits is not None else False
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-
-        # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        this_peer_finished = False
-
-        # model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-
-        while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-
-            # sample
-            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            ########################################################################################################
-            # thkim change for right-padding batch
-            # if min_input_len <= update_idx < max_input_len
-            # update validate input_ids[:,update_idx]
-            # TODO : raw_logits contains dummy next_token's logits
-            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-            if update_idx < self.rightpad_max_len:
-                # update exist input_ids rather than concat
-                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
-                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
-                model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
-                # dummy next_token -> pad_token_id for streamer
-                # in order to skip by 'skip_special_tokens = True"
-                dummy_indices = ~valid_indices
-                next_tokens[dummy_indices] = pad_token_id
-            else:
-                ############################################END#########################################################
-
-                # update generated ids, model inputs, and length for next step
-                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                ####################################################################
-                # thkim : to do not finish sequence of dummy_decoder of right_padding
-                if hasattr(self, "rightpad_max_len"):
-                    update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-                    if update_idx < self.rightpad_max_len:
-                        next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
-                ######################################################################
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            # thkim : backward compatibility bool vs list[bool]
-            is_stop = stopping_criteria(input_ids, None)
-            if isinstance(is_stop, torch.BoolTensor):
-                is_stop = torch.all(is_stop)
-            if is_stop:
-                this_peer_finished = True
-
-            if this_peer_finished:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return SampleDecoderOnlyOutput(
-                sequences=input_ids,
-                scores=scores,
-                logits=raw_logits,
-            )
-        else:
-            return input_ids
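Taken together with the new `optimum/rbln/transformers/generation/utils.py` (+399 lines in the file list), the added `RBLNGenerationMixin` import, and the changed class bases, the removal above suggests the per-model copies of the decoding loops were consolidated into a shared mixin. A hypothetical, minimal sketch of that pattern (the class and method bodies here are illustrative only, not the library's actual code):

```python
# Minimal, hypothetical sketch of moving a shared decoding loop into a mixin.
import torch


class SharedGenerationMixin:
    def _greedy_search(self, input_ids: torch.LongTensor, max_new_tokens: int = 8) -> torch.LongTensor:
        # One decoding loop shared by every causal-LM wrapper that mixes this in.
        for _ in range(max_new_tokens):
            logits = self(input_ids)                      # each model supplies __call__/forward
            next_tokens = logits[:, -1, :].argmax(dim=-1)
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        return input_ids


class ToyGPT2(SharedGenerationMixin):
    def __call__(self, input_ids):
        return torch.randn(input_ids.shape[0], input_ids.shape[1], 50257)  # dummy logits


class ToyLlama(SharedGenerationMixin):
    def __call__(self, input_ids):
        return torch.randn(input_ids.shape[0], input_ids.shape[1], 32000)  # dummy logits


tokens = ToyGPT2()._greedy_search(torch.tensor([[101, 102, 103]]), max_new_tokens=3)
print(tokens.shape)  # torch.Size([1, 6])
```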
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -34,6 +34,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
+from ...generation.utils import RBLNGenerationMixin
 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
@@ -75,7 +76,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         return logits
 
 
-class RBLNLlamaForCausalLM(RBLNBaseModel):
+class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -91,7 +92,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
     auto_model_class = AutoModelForCausalLM
 
     def __post_init__(self, **kwargs):
-
         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
@@ -106,6 +106,7 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
         self.past_cached_length = 0
+        self.right_padding = True
 
     @classmethod
     @torch.no_grad()
@@ -120,14 +121,23 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNLlamaForCausalLM":
         task = kwargs.pop("task", None)
         if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-
-
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         def update_configs(kwargs):
             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
@@ -245,6 +255,7 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_batch_size"] = rbln_batch_size
@@ -321,23 +332,47 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size,
+        batch_size, cur_len = input_ids.shape
         past_cached_length = past_key_values
-        query_length = hf_input_length - past_cached_length
 
         # In greedy decoding
-        if
-
-
-
-
-
+        if past_cached_length == 0:
+
+            # padding with prefill_chunk_size
+            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
+            if cur_len % self.prefill_chunk_size != 0:
+                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
+                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
+
+            # padding_side
+            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
+                self.right_padding = False
+
+            if self.right_padding:
+                self.rightpad_max_len = cur_len
+                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # dummy_decoder generation length
+                query_length = prompt_min_len.item()
+            else:
+                query_length = cur_len - past_cached_length
+                self.prompt_length = query_length
+                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
+
+            attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(0, dtype=torch.int32)
+
         else:
-
-
+            if self.right_padding:
+                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
+                attention_mask[:, :, :, : past_cached_length + 1] = 1
+                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
+            else:
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
+                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+                input_ids = input_ids[:, -1:]
+
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
+            query_length = 1
 
         model_inputs = {
             "input_ids": input_ids,
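`prepare_inputs_for_generation` above now pads the prompt to a multiple of `prefill_chunk_size` and infers the padding side: a multi-sequence batch whose last attention-mask column is all ones is treated as left-padded, everything else as right-padded, with `dummy_len` and `prompt_attn_mask` recorded accordingly. An illustrative, torch-only sketch of that preparation (the example tensors and the zero past length are assumptions):

```python
# Illustrative sketch of the prompt preparation added above.
import torch

prefill_chunk_size = 4
input_ids = torch.tensor([[11, 12, 13, 14, 15, 0], [21, 22, 23, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0], [1, 1, 1, 0, 0, 0]])  # right padded
batch_size, cur_len = input_ids.shape

# pad the prompt up to a multiple of the prefill chunk size
if cur_len % prefill_chunk_size != 0:
    pad_len = prefill_chunk_size - cur_len % prefill_chunk_size
    input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))

# left-padded batches end with a fully valid last column; everything else is right padding
right_padding = not (batch_size > 1 and torch.all(attention_mask[..., -1] == 1))

if right_padding:
    prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
    dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # extra dummy steps per sequence
    query_length = prompt_min_len.item()
else:
    query_length = cur_len  # cur_len - past_cached_length in the real code; past length is 0 here
    prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1)

print(input_ids.shape, bool(right_padding), query_length)
```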
@@ -358,43 +393,26 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         query_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
+
         if past_key_values is not None:
             past_key_values += query_length
 
         # prefill_decoder
         if cache_position == 0:
-
-
-
-            attention_mask[:, :, :, :
-
-
-
-
-                sliced_input_ids,
-                attention_mask,
-                cache_position,
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if not self.right_padding:
+                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
+
+                outputs = self.prefill_decoder(
+                    input_ids=sliced_input_ids.contiguous(),
+                    attention_mask=attention_mask.contiguous(),
+                    cache_position=cache_position + step,
                 )
-
-            query_length -= self.prefill_chunk_size
-            cache_position += self.prefill_chunk_size
-
-            # prepare input_ids & attention_mask
-            last_input_ids = input_ids[:, cache_position : cache_position + query_length]
-            last_input_ids = torch.nn.functional.pad(last_input_ids, (0, self.prefill_chunk_size - query_length))
-
-            attention_mask[:, :, :, :cache_position] = 1
-            mask_slice = self.causal_mask[:, :, :query_length, :query_length]
-            attention_mask[:, :, :query_length, cache_position : cache_position + query_length] = mask_slice
-            attention_mask[:, :, :, : self.prompt_length] *= self.prompt_attn_mask[:, :, :, :]
-
-            outputs = self.prefill_decoder(
-                last_input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position,
-            )
+            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
 
-            outputs = outputs[:, query_length - 1].unsqueeze(1)
         # decoder
         else:
             outputs = self.decoder(
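For left-padded batches, the chunked prefill above additionally AND-merges the running mask with the stored prompt attention mask so padded prompt positions never become visible in any chunk. A hedged sketch of that merge with illustrative shapes:

```python
# Sketch of the prompt-mask merge used for left-padded prompts above.
import torch

prefill_chunk_size, max_seq_len, prompt_length = 4, 8, 8
causal_mask = torch.tril(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size, dtype=torch.int64))

# two sequences; the first is left-padded by two positions
prompt_attn_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1],
                                 [1, 1, 1, 1, 1, 1, 1, 1]]).view(2, 1, 1, prompt_length)

attention_mask = torch.zeros(2, 1, prefill_chunk_size, max_seq_len, dtype=torch.int64)
for step in range(0, prompt_length, prefill_chunk_size):
    attention_mask[:, :, :, :step] = 1
    attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask
    # left padding: no position may attend to the padded prompt slots
    attention_mask[:, :, :, :prompt_length] &= prompt_attn_mask

print(attention_mask[0, 0])  # first two columns stay 0 for the left-padded sequence
```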
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -163,6 +163,7 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNWhisperForConditionalGeneration":
         """
@@ -172,8 +173,16 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-
-
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         kwargs.update(
             {
@@ -266,7 +275,7 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
         model_config: "PretrainedConfig",
-        rbln_batch_size: Optional[int] =
+        rbln_batch_size: Optional[int] = None,
     ) -> RBLNConfig:
         meta = {}
 
@@ -287,6 +296,7 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         if rbln_dec_max_seq_len is None:
             raise ValueError("`rbln_dec_max_seq_len` should be specified!")
 
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         decoder_batch_size = rbln_batch_size
 
         meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
{optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: optimum-rbln
-Version: 0.1.0
+Version: 0.1.1
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
 It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Keywords: transformers,diffusers,inference,rbln,atom,rebel