optimum-rbln 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. optimum/rbln/__init__.py +6 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +7 -0
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
  5. optimum/rbln/diffusers/models/controlnet.py +93 -23
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
  7. optimum/rbln/diffusers/pipelines/__init__.py +7 -2
  8. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
  10. optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
  15. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
  18. optimum/rbln/modeling_base.py +36 -3
  19. optimum/rbln/modeling_seq2seq.py +19 -4
  20. optimum/rbln/transformers/generation/__init__.py +1 -0
  21. optimum/rbln/transformers/generation/streamers.py +17 -0
  22. optimum/rbln/transformers/generation/utils.py +399 -0
  23. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
  24. optimum/rbln/transformers/models/llama/modeling_llama.py +63 -45
  25. optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
  26. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA +1 -1
  27. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/RECORD +29 -25
  28. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/WHEEL +0 -0
  29. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,7 +23,6 @@
 
 import inspect
 import logging
-import warnings
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -32,19 +31,13 @@ import rebel
 import torch
 from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
-from transformers.generation.logits_process import LogitsProcessorList
-from transformers.generation.stopping_criteria import (
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from transformers.generation.streamers import BaseStreamer
-from transformers.generation.utils import SampleDecoderOnlyOutput
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
 
 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.save_utils import maybe_save_preprocessors
+from ...generation.utils import RBLNGenerationMixin
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
@@ -66,7 +59,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=logits)
 
 
-class RBLNGPT2LMHeadModel(RBLNBaseModel):
+class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
     """
    The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
@@ -135,6 +128,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNGPT2LMHeadModel":
         """
@@ -144,8 +138,16 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         kwargs.update(
             {
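
Note: the recurring addition in this release is a `model_save_dir` argument on the export entry points, so compiled artifacts can land in a caller-provided directory instead of a throwaway TemporaryDirectory. A minimal usage sketch, assuming the class is re-exported from the package root and the usual optimum-style `from_pretrained(..., export=True)` path reaches the code above (neither is shown in this diff):

from pathlib import Path

from optimum.rbln import RBLNGPT2LMHeadModel  # assumption: re-exported at the package root

# Hypothetical persistent directory, reused across runs.
compiled_dir = Path("rbln_gpt2_artifacts")

model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",                       # assumed checkpoint id
    export=True,                  # assumed optimum-style flag routing through the export path above
    model_save_dir=compiled_dir,  # new in 0.1.1: keep the compiled model here
)
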
@@ -264,8 +266,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_pad_token_id"] = rbln_pad_token_id
 
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         def get_input_info(query_length):
             return [
@@ -320,6 +321,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
             self.prompt_ids = input_ids
             self.rightpad_max_len = cur_len
             prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+            self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
 
             if cur_len % self.prefill_chunk_size == 0:
                 pad_len = 0
@@ -329,12 +331,12 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
             attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
 
-            query_length = prompt_min_len
+            query_length = prompt_min_len.item()
         else:
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
             attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
             attention_mask[:, :, :, : cache_position + 1] = 1
-            input_ids = input_ids[:, -1:].contiguous()
+            input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
             query_length = 1
 
         model_inputs = {
@@ -357,25 +359,23 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
         query_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-
         if past_key_values is not None:
             past_key_values += query_length
 
         if cache_position == 0:
-            for _ in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, cache_position : cache_position + self.prefill_chunk_size]
-                attention_mask[:, :, :, :cache_position] = 1
-                attention_mask[:, :, :, cache_position : cache_position + self.prefill_chunk_size] = self.causal_mask
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
                 output = self.prefill_decoder(
                     input_ids=sliced_input_ids.contiguous(),
                     attention_mask=attention_mask.contiguous(),
-                    cache_position=cache_position,
+                    cache_position=cache_position + step,
                 )
-                query_length -= self.prefill_chunk_size
-                cache_position += self.prefill_chunk_size
 
-            output = output.logits[:, query_length - 1].unsqueeze(1)
+            idx = query_length % self.prefill_chunk_size - 1
+            output = output.logits[:, idx].unsqueeze(1)
 
         else:
             output = self.decoder(
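
Note: the rewritten prefill path above walks the prompt in fixed-size chunks and reads the logit of the last real prompt token from the final chunk. A standalone sketch of that loop, with a plain callable standing in for the compiled RBLN prefill runtime and `causal_mask` assumed to be a lower-triangular (1, 1, chunk, chunk) mask:

def chunked_prefill(prefill_decoder, input_ids, attention_mask, causal_mask,
                    query_length, prefill_chunk_size, cache_position=0):
    output = None
    for step in range(0, query_length, prefill_chunk_size):
        sliced_input_ids = input_ids[:, step : step + prefill_chunk_size]
        # Tokens before the current chunk are fully visible; the chunk itself is causal.
        attention_mask[:, :, :, :step] = 1
        attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask
        output = prefill_decoder(
            input_ids=sliced_input_ids.contiguous(),
            attention_mask=attention_mask.contiguous(),
            cache_position=cache_position + step,
        )
    # Index of the last real prompt token inside the final chunk
    # (a multiple of the chunk size selects -1, i.e. the chunk's last slot).
    idx = query_length % prefill_chunk_size - 1
    return output.logits[:, idx].unsqueeze(1)

Offsetting `cache_position` by `step` instead of mutating it keeps the final-logit index computable from `query_length` alone, which is what the `idx` line in the hunk relies on.
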
@@ -389,312 +389,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
 
     def __repr__(self):
         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
-
-    # call 'greedy_search` directly is deprecated and removed in v4.41.
-    def greedy_search(self, *args, **kwargs):
-        return self._greedy_search(*args, **kwargs)
-
-    def _greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_logits: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        return_dict_in_generate = (
-            return_dict_in_generate
-            if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
-        )
-
-        # init attention / hidden states / scores tuples
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-
-        # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-
-        this_peer_finished = False  # used by synced_gpus only
-
-        while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-            )
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-
-            # argmax
-            next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            ########################################################################################################
-            # thkim change for right-padding batch
-            # if min_input_len <= update_idx < max_input_len
-            #     update validate input_ids[:,update_idx]
-            # TODO : raw_logits contains dummy next_token's logits
-            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-            if update_idx < self.rightpad_max_len:
-                # update exist input_ids rather than concat
-                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
-                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
-                model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
-                # dummy next_token -> pad_token_id for streamer
-                # in order to skip by 'skip_special_tokens = True"
-                dummy_indices = ~valid_indices
-                next_tokens[dummy_indices] = pad_token_id
-            else:
-                ############################################END#########################################################
-                # update generated ids, model inputs, and length for next step
-                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                ####################################################################
-                # thkim : to do not finish sequence of dummy_decoder of right_padding
-                if hasattr(self, "rightpad_max_len"):
-                    update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-                    if update_idx < self.rightpad_max_len:
-                        next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
-                ######################################################################
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-                # stop when each sentence is finished
-                if unfinished_sequences.max() == 0:
-                    this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            # thkim : backward compatibility bool vs torch.BoolTensor
-            is_stop = stopping_criteria(input_ids, None)
-            if isinstance(is_stop, torch.BoolTensor):
-                is_stop = torch.all(is_stop)
-            if is_stop:
-                this_peer_finished = True
-
-            if this_peer_finished:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return SampleDecoderOnlyOutput(
-                sequences=input_ids,
-                logits=raw_logits,
-            )
-        else:
-            return input_ids
-
-    # call 'sample` directly is deprecated and removed in v4.41.
-    def sample(self, *args, **kwargs):
-        return self._sample(*args, **kwargs)
-
-    def _sample(
-        self,
-        input_ids: torch.LongTensor,
-        logits_processor: Optional[LogitsProcessorList] = None,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        logits_warper: Optional[LogitsProcessorList] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        output_logits: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        synced_gpus: bool = False,
-        streamer: Optional["BaseStreamer"] = None,
-        **model_kwargs,
-    ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-        # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
-        if max_length is not None:
-            warnings.warn(
-                "`max_length` is deprecated in this function, use"
-                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
-                UserWarning,
-            )
-            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_logits = output_logits if output_logits is not None else False
-
-        # init attention / hidden states / scores tuples
-        scores = () if (return_dict_in_generate and output_scores) else None
-        raw_logits = () if (return_dict_in_generate and output_logits) else None
-
-        # keep track of which sequences are already finished
-        batch_size, cur_len = input_ids.shape
-        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-        this_peer_finished = False
-
-        # model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-
-        while True:
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            next_token_logits = outputs.logits[:, -1, :]
-
-            # pre-process distribution
-            next_token_scores = logits_processor(input_ids, next_token_logits)
-            next_token_scores = logits_warper(input_ids, next_token_scores)
-
-            # Store scores, attentions and hidden_states when required
-            if return_dict_in_generate:
-                if output_scores:
-                    scores += (next_token_scores,)
-                if output_logits:
-                    raw_logits += (next_token_logits,)
-
-            # sample
-            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-            ########################################################################################################
-            # thkim change for right-padding batch
-            # if min_input_len <= update_idx < max_input_len
-            #     update validate input_ids[:,update_idx]
-            # TODO : raw_logits contains dummy next_token's logits
-            update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-            if update_idx < self.rightpad_max_len:
-                # update exist input_ids rather than concat
-                valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
-                input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
-                model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
-                # dummy next_token -> pad_token_id for streamer
-                # in order to skip by 'skip_special_tokens = True"
-                dummy_indices = ~valid_indices
-                next_tokens[dummy_indices] = pad_token_id
-            else:
-                ############################################END#########################################################
-
-                # update generated ids, model inputs, and length for next step
-                input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            model_kwargs = self._update_model_kwargs_for_generation(
-                outputs,
-                model_kwargs,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-            )
-
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
-
-            # if eos_token was found in one sentence, set sentence to finished
-            if eos_token_id_tensor is not None:
-                ####################################################################
-                # thkim : to do not finish sequence of dummy_decoder of right_padding
-                if hasattr(self, "rightpad_max_len"):
-                    update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
-                    if update_idx < self.rightpad_max_len:
-                        next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
-                ######################################################################
-                unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-                )
-
-                # stop when each sentence is finished
-                if unfinished_sequences.max() == 0:
-                    this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            # thkim : backward compatibility bool vs list[bool]
-            is_stop = stopping_criteria(input_ids, None)
-            if isinstance(is_stop, torch.BoolTensor):
-                is_stop = torch.all(is_stop)
-            if is_stop:
-                this_peer_finished = True
-
-            if this_peer_finished:
-                break
-
-        if streamer is not None:
-            streamer.end()
-
-        if return_dict_in_generate:
-            return SampleDecoderOnlyOutput(
-                sequences=input_ids,
-                scores=scores,
-                logits=raw_logits,
-            )
-        else:
-            return input_ids

optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -34,6 +34,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, LlamaForCausalLM, PretrainedConfig, AutoConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
+from ...generation.utils import RBLNGenerationMixin
 from ....modeling_base import RBLNBaseModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
@@ -75,7 +76,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         return logits
 
 
-class RBLNLlamaForCausalLM(RBLNBaseModel):
+class RBLNLlamaForCausalLM(RBLNBaseModel, RBLNGenerationMixin):
     """
    The Llama Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -91,7 +92,6 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
     auto_model_class = AutoModelForCausalLM
 
     def __post_init__(self, **kwargs):
-
         self.batch_size = self.rbln_config.meta["rbln_batch_size"]
         self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
@@ -106,6 +106,7 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         self.prefill_decoder = RBLNRuntimeModel(runtime=self.runtimes[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeModel(runtime=self.runtimes[1], main_input_name="input_ids")
         self.past_cached_length = 0
+        self.right_padding = True
 
     @classmethod
     @torch.no_grad()
@@ -120,14 +121,23 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNLlamaForCausalLM":
         task = kwargs.pop("task", None)
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         def update_configs(kwargs):
             hf_max_position_embeddings = getattr(AutoConfig.from_pretrained(model_id), "max_position_embeddings", None)
@@ -245,6 +255,7 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
             rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_batch_size"] = rbln_batch_size
@@ -321,23 +332,47 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
 
     # args input_ids, past_key_values and attention_mask are updated by _update_model_kwargs_for_generation() in _greedy_search() in GenerationMixin
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, hf_input_length = input_ids.shape
+        batch_size, cur_len = input_ids.shape
         past_cached_length = past_key_values
-        query_length = hf_input_length - past_cached_length
 
         # In greedy decoding
-        if past_key_values == 0:
-            self.prompt_length = query_length
-            self.prompt_ids = input_ids
-            self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
-
-            attention_mask = torch.zeros(batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64)
+        if past_cached_length == 0:
+
+            # padding with prefill_chunk_size
+            # TODO left padding + left padding has issue on stoppingcriteria(max_len)
+            if cur_len % self.prefill_chunk_size != 0:
+                pad_len = self.prefill_chunk_size - cur_len % self.prefill_chunk_size
+                input_ids = torch.nn.functional.pad(input_ids, (0, pad_len))
+
+            # padding_side
+            if batch_size > 1 and torch.all(attention_mask[..., -1] == 1):
+                self.right_padding = False
+
+            if self.right_padding:
+                self.rightpad_max_len = cur_len
+                prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+                self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len  # dummy_decoder generation length
+                query_length = prompt_min_len.item()
+            else:
+                query_length = cur_len - past_cached_length
+                self.prompt_length = query_length
+                self.prompt_attn_mask = attention_mask.unsqueeze(1).unsqueeze(1).contiguous()
+
+            attention_mask = self.prefill_attention_mask.clone()
             cache_position = torch.tensor(0, dtype=torch.int32)
+
         else:
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - hf_input_length))
-            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+            if self.right_padding:
+                attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
+                attention_mask[:, :, :, : past_cached_length + 1] = 1
+                input_ids = input_ids[:, past_cached_length : past_cached_length + 1].contiguous()
+            else:
+                attention_mask = torch.nn.functional.pad(attention_mask, (0, self.max_seq_len - cur_len))
+                attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+                input_ids = input_ids[:, -1:]
+
             cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            input_ids = input_ids[:, -1:]
+            query_length = 1
 
         model_inputs = {
             "input_ids": input_ids,
@@ -358,43 +393,26 @@ class RBLNLlamaForCausalLM(RBLNBaseModel):
         query_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
+
         if past_key_values is not None:
             past_key_values += query_length
 
         # prefill_decoder
         if cache_position == 0:
-            while query_length > self.prefill_chunk_size:
-                # prepare input_ids & attention_mask
-                sliced_input_ids = input_ids[:, cache_position : cache_position + self.prefill_chunk_size].contiguous()
-                attention_mask[:, :, :, :cache_position] = 1
-                attention_mask[:, :, :, cache_position : cache_position + self.prefill_chunk_size] = self.causal_mask
-                attention_mask[:, :, :, : self.prompt_length] *= self.prompt_attn_mask[:, :, :, :]
-
-                _ = self.prefill_decoder(
-                    sliced_input_ids,
-                    attention_mask,
-                    cache_position,
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if not self.right_padding:
+                    attention_mask[:, :, :, : self.prompt_length] &= self.prompt_attn_mask[:, :, :, :]
+
+                outputs = self.prefill_decoder(
+                    input_ids=sliced_input_ids.contiguous(),
+                    attention_mask=attention_mask.contiguous(),
+                    cache_position=cache_position + step,
                 )
-                # update query_length & cache_position
-                query_length -= self.prefill_chunk_size
-                cache_position += self.prefill_chunk_size
-
-            # prepare input_ids & attention_mask
-            last_input_ids = input_ids[:, cache_position : cache_position + query_length]
-            last_input_ids = torch.nn.functional.pad(last_input_ids, (0, self.prefill_chunk_size - query_length))
-
-            attention_mask[:, :, :, :cache_position] = 1
-            mask_slice = self.causal_mask[:, :, :query_length, :query_length]
-            attention_mask[:, :, :query_length, cache_position : cache_position + query_length] = mask_slice
-            attention_mask[:, :, :, : self.prompt_length] *= self.prompt_attn_mask[:, :, :, :]
-
-            outputs = self.prefill_decoder(
-                last_input_ids.contiguous(),
-                attention_mask.contiguous(),
-                cache_position,
-            )
+            outputs = outputs[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
 
-            outputs = outputs[:, query_length - 1].unsqueeze(1)
         # decoder
         else:
             outputs = self.decoder(

optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -163,6 +163,7 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         **kwargs,
     ) -> "RBLNWhisperForConditionalGeneration":
         """
@@ -172,8 +173,16 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
+        if model_save_dir is None:
+            save_dir = TemporaryDirectory()
+            save_dir_path = Path(save_dir.name)
+        else:
+            save_dir = model_save_dir
+            if isinstance(save_dir, TemporaryDirectory):
+                save_dir_path = Path(model_save_dir.name)
+            else:
+                save_dir_path = Path(model_save_dir)
+                save_dir_path.mkdir(exist_ok=True)
 
         kwargs.update(
             {
@@ -266,7 +275,7 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
         model_config: "PretrainedConfig",
-        rbln_batch_size: Optional[int] = 1,
+        rbln_batch_size: Optional[int] = None,
     ) -> RBLNConfig:
         meta = {}
 
@@ -287,6 +296,7 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         if rbln_dec_max_seq_len is None:
             raise ValueError("`rbln_dec_max_seq_len` should be specified!")
 
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         decoder_batch_size = rbln_batch_size
 
         meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len

{optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.1.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: optimum-rbln
-Version: 0.1.0
+Version: 0.1.1
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
         It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Keywords: transformers,diffusers,inference,rbln,atom,rebel