optimum-rbln 0.1.0__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their respective public registries.
Files changed (41)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +7 -0
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -9
  5. optimum/rbln/diffusers/models/controlnet.py +93 -23
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +78 -61
  7. optimum/rbln/diffusers/pipelines/__init__.py +7 -2
  8. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +4 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +768 -0
  10. optimum/rbln/diffusers/pipelines/{stable_diffusion → controlnet}/pipeline_controlnet_img2img.py +25 -16
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +942 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +955 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -4
  15. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -9
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +19 -3
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +19 -3
  18. optimum/rbln/modeling_base.py +39 -6
  19. optimum/rbln/modeling_seq2seq.py +19 -4
  20. optimum/rbln/transformers/__init__.py +2 -0
  21. optimum/rbln/transformers/generation/__init__.py +1 -0
  22. optimum/rbln/transformers/generation/streamers.py +17 -0
  23. optimum/rbln/transformers/generation/utils.py +399 -0
  24. optimum/rbln/transformers/models/__init__.py +1 -0
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +24 -333
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +49 -17
  27. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +759 -0
  28. optimum/rbln/transformers/models/llama/modeling_llama.py +187 -75
  29. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  32. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  33. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  34. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  35. optimum/rbln/transformers/models/midm/modeling_midm.py +426 -0
  36. optimum/rbln/transformers/models/whisper/modeling_whisper.py +13 -3
  37. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/METADATA +5 -4
  38. optimum_rbln-0.1.4.dist-info/RECORD +63 -0
  39. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/WHEEL +1 -1
  40. optimum_rbln-0.1.0.dist-info/RECORD +0 -51
  41. {optimum_rbln-0.1.0.dist-info → optimum_rbln-0.1.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -23,7 +23,6 @@
 
  import inspect
  import logging
- import warnings
  from pathlib import Path
  from tempfile import TemporaryDirectory
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -32,19 +31,13 @@ import rebel
  import torch
  from optimum.exporters import TasksManager
  from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
- from transformers.generation.logits_process import LogitsProcessorList
- from transformers.generation.stopping_criteria import (
- StoppingCriteriaList,
- validate_stopping_criteria,
- )
- from transformers.generation.streamers import BaseStreamer
- from transformers.generation.utils import SampleDecoderOnlyOutput
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
 
  from ....modeling_base import RBLNBaseModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ....utils.save_utils import maybe_save_preprocessors
+ from ...generation.utils import RBLNGenerationMixin
  from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
@@ -66,7 +59,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  return Seq2SeqLMOutput(logits=logits)
 
 
- class RBLNGPT2LMHeadModel(RBLNBaseModel):
+ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
  """
  The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
  embeddings).
@@ -135,6 +128,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
  subfolder: str = "",
  local_files_only: bool = False,
  trust_remote_code: bool = False,
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
  **kwargs,
  ) -> "RBLNGPT2LMHeadModel":
  """
@@ -144,8 +138,16 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
  if task is None:
  task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
- save_dir = TemporaryDirectory()
- save_dir_path = Path(save_dir.name)
+ if model_save_dir is None:
+ save_dir = TemporaryDirectory()
+ save_dir_path = Path(save_dir.name)
+ else:
+ save_dir = model_save_dir
+ if isinstance(save_dir, TemporaryDirectory):
+ save_dir_path = Path(model_save_dir.name)
+ else:
+ save_dir_path = Path(model_save_dir)
+ save_dir_path.mkdir(exist_ok=True)
 
  kwargs.update(
  {
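For reference, a minimal standalone sketch of the export-directory selection introduced in this hunk; the `resolve_save_dir` helper name is illustrative and not part of the package:

```python
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple, Union


def resolve_save_dir(
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
) -> Tuple[Union[TemporaryDirectory, str, Path], Path]:
    """Mirror of the directory handling above: reuse a caller-provided location
    when given, otherwise fall back to a fresh TemporaryDirectory."""
    if model_save_dir is None:
        save_dir = TemporaryDirectory()      # scratch dir, removed when the object is cleaned up
        save_dir_path = Path(save_dir.name)
    else:
        save_dir = model_save_dir            # keep a handle so a TemporaryDirectory is not collected early
        if isinstance(save_dir, TemporaryDirectory):
            save_dir_path = Path(save_dir.name)
        else:
            save_dir_path = Path(save_dir)
        save_dir_path.mkdir(exist_ok=True)   # a caller-supplied path may not exist yet
    return save_dir, save_dir_path
```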
@@ -264,8 +266,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
  meta["rbln_max_seq_len"] = rbln_max_seq_len
  meta["rbln_pad_token_id"] = rbln_pad_token_id
 
- if rbln_batch_size is None:
- rbln_batch_size = 1
+ rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
  def get_input_info(query_length):
  return [
@@ -320,6 +321,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
  self.prompt_ids = input_ids
  self.rightpad_max_len = cur_len
  prompt_min_len = torch.min(torch.sum(attention_mask, dim=-1))
+ self.dummy_len = torch.sum(attention_mask, dim=-1) - prompt_min_len
 
  if cur_len % self.prefill_chunk_size == 0:
  pad_len = 0
@@ -329,12 +331,12 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
  attention_mask = self.prefill_attention_mask.clone()
  cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
 
- query_length = prompt_min_len
+ query_length = prompt_min_len.item()
  else:
  cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
  attention_mask = torch.zeros(batch_size, 1, 1, self.max_seq_len, dtype=torch.int64)
  attention_mask[:, :, :, : cache_position + 1] = 1
- input_ids = input_ids[:, -1:].contiguous()
+ input_ids = input_ids[:, cache_position : cache_position + 1].contiguous()
  query_length = 1
 
  model_inputs = {
@@ -357,25 +359,23 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
  query_length: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-
  if past_key_values is not None:
  past_key_values += query_length
 
  if cache_position == 0:
- for _ in range(0, query_length, self.prefill_chunk_size):
- sliced_input_ids = input_ids[:, cache_position : cache_position + self.prefill_chunk_size]
- attention_mask[:, :, :, :cache_position] = 1
- attention_mask[:, :, :, cache_position : cache_position + self.prefill_chunk_size] = self.causal_mask
+ for step in range(0, query_length, self.prefill_chunk_size):
+ sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+ attention_mask[:, :, :, :step] = 1
+ attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
  output = self.prefill_decoder(
  input_ids=sliced_input_ids.contiguous(),
  attention_mask=attention_mask.contiguous(),
- cache_position=cache_position,
+ cache_position=cache_position + step,
  )
- query_length -= self.prefill_chunk_size
- cache_position += self.prefill_chunk_size
 
- output = output.logits[:, query_length - 1].unsqueeze(1)
+ idx = query_length % self.prefill_chunk_size - 1
+ output = output.logits[:, idx].unsqueeze(1)
 
  else:
  output = self.decoder(
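The rewritten prefill loop above walks the prompt in fixed-size chunks and then picks the logit of the last real prompt token inside the final chunk. A standalone illustration of that indexing (chunk size and prompt lengths are arbitrary example values, not package defaults):

```python
# Illustration only: shows why `query_length % prefill_chunk_size - 1` selects the
# last valid prompt position in the final prefill chunk.
prefill_chunk_size = 128

for query_length in (1, 100, 128, 200, 256):
    # The prompt is fed to the prefill decoder in windows starting at
    # step = 0, prefill_chunk_size, 2 * prefill_chunk_size, ...
    last_step = ((query_length - 1) // prefill_chunk_size) * prefill_chunk_size
    # Index of the last real prompt token inside the final window. When the prompt
    # exactly fills the window this evaluates to -1, which still selects the last
    # position thanks to negative indexing on the logits tensor.
    idx = query_length % prefill_chunk_size - 1
    assert idx % prefill_chunk_size == (query_length - 1) - last_step
    print(f"query_length={query_length}: last chunk starts at {last_step}, logits index {idx}")
```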
@@ -389,312 +389,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel):
 
  def __repr__(self):
  return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
-
- # call 'greedy_search` directly is deprecated and removed in v4.41.
- def greedy_search(self, *args, **kwargs):
- return self._greedy_search(*args, **kwargs)
-
- def _greedy_search(
- self,
- input_ids: torch.LongTensor,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
- output_logits: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- streamer: Optional["BaseStreamer"] = None,
- **model_kwargs,
- ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
-
- # init values
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
- if max_length is not None:
- warnings.warn(
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
- pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
- eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.generation_config.return_dict_in_generate
- )
-
- # init attention / hidden states / scores tuples
- raw_logits = () if (return_dict_in_generate and output_logits) else None
-
- # keep track of which sequences are already finished
- unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-
- this_peer_finished = False # used by synced_gpus only
-
- while True:
- # prepare model inputs
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- # forward pass to get next token
- outputs = self(
- **model_inputs,
- return_dict=True,
- )
- next_token_logits = outputs.logits[:, -1, :]
-
- # pre-process distribution
- next_tokens_scores = logits_processor(input_ids, next_token_logits)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_logits:
- raw_logits += (next_token_logits,)
-
- # argmax
- next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
- # finished sentences should have their next token be a padding token
- if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
- ########################################################################################################
- # thkim change for right-padding batch
- # if min_input_len <= update_idx < max_input_len
- # update validate input_ids[:,update_idx]
- # TODO : raw_logits contains dummy next_token's logits
- update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
- if update_idx < self.rightpad_max_len:
- # update exist input_ids rather than concat
- valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
- input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
- model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
- # dummy next_token -> pad_token_id for streamer
- # in order to skip by 'skip_special_tokens = True"
- dummy_indices = ~valid_indices
- next_tokens[dummy_indices] = pad_token_id
- else:
- ############################################END#########################################################
- # update generated ids, model inputs, and length for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs,
- model_kwargs,
- is_encoder_decoder=self.config.is_encoder_decoder,
- )
-
- if streamer is not None:
- streamer.put(next_tokens.cpu())
-
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id_tensor is not None:
- ####################################################################
- # thkim : to do not finish sequence of dummy_decoder of right_padding
- if hasattr(self, "rightpad_max_len"):
- update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
- if update_idx < self.rightpad_max_len:
- next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
- ######################################################################
- unfinished_sequences = unfinished_sequences.mul(
- next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
- )
-
- # stop when each sentence is finished
- if unfinished_sequences.max() == 0:
- this_peer_finished = True
-
- # stop if we exceed the maximum length
- # thkim : backward compatibility bool vs torch.BoolTensor
- is_stop = stopping_criteria(input_ids, None)
- if isinstance(is_stop, torch.BoolTensor):
- is_stop = torch.all(is_stop)
- if is_stop:
- this_peer_finished = True
-
- if this_peer_finished:
- break
-
- if streamer is not None:
- streamer.end()
-
- if return_dict_in_generate:
- return SampleDecoderOnlyOutput(
- sequences=input_ids,
- logits=raw_logits,
- )
- else:
- return input_ids
-
- # call 'sample` directly is deprecated and removed in v4.41.
- def sample(self, *args, **kwargs):
- return self._sample(*args, **kwargs)
-
- def _sample(
- self,
- input_ids: torch.LongTensor,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- logits_warper: Optional[LogitsProcessorList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- output_logits: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: bool = False,
- streamer: Optional["BaseStreamer"] = None,
- **model_kwargs,
- ) -> Union[SampleDecoderOnlyOutput, torch.LongTensor]:
- # init values
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-
- if max_length is not None:
- warnings.warn(
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
- logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
- pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
- eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-
- output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
- output_logits = output_logits if output_logits is not None else False
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- raw_logits = () if (return_dict_in_generate and output_logits) else None
-
- # keep track of which sequences are already finished
- batch_size, cur_len = input_ids.shape
- unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
- this_peer_finished = False
-
- # model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-
- while True:
- # prepare model inputs
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- # forward pass to get next token
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- next_token_logits = outputs.logits[:, -1, :]
-
- # pre-process distribution
- next_token_scores = logits_processor(input_ids, next_token_logits)
- next_token_scores = logits_warper(input_ids, next_token_scores)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += (next_token_scores,)
- if output_logits:
- raw_logits += (next_token_logits,)
-
- # sample
- probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
- # finished sentences should have their next token be a padding token
- if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
- ########################################################################################################
- # thkim change for right-padding batch
- # if min_input_len <= update_idx < max_input_len
- # update validate input_ids[:,update_idx]
- # TODO : raw_logits contains dummy next_token's logits
- update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
- if update_idx < self.rightpad_max_len:
- # update exist input_ids rather than concat
- valid_indices = model_kwargs["attention_mask"][:, update_idx] == 0
- input_ids[valid_indices, update_idx] = next_tokens[valid_indices]
- model_kwargs["attention_mask"][valid_indices, update_idx] = 1
-
- # dummy next_token -> pad_token_id for streamer
- # in order to skip by 'skip_special_tokens = True"
- dummy_indices = ~valid_indices
- next_tokens[dummy_indices] = pad_token_id
- else:
- ############################################END#########################################################
-
- # update generated ids, model inputs, and length for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
- model_kwargs = self._update_model_kwargs_for_generation(
- outputs,
- model_kwargs,
- is_encoder_decoder=self.config.is_encoder_decoder,
- )
-
- if streamer is not None:
- streamer.put(next_tokens.cpu())
-
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id_tensor is not None:
- ####################################################################
- # thkim : to do not finish sequence of dummy_decoder of right_padding
- if hasattr(self, "rightpad_max_len"):
- update_idx = model_inputs["cache_position"] + model_inputs["query_length"]
- if update_idx < self.rightpad_max_len:
- next_tokens += model_kwargs["attention_mask"][:, update_idx] * eos_token_id_tensor
- ######################################################################
- unfinished_sequences = unfinished_sequences.mul(
- next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
- )
-
- # stop when each sentence is finished
- if unfinished_sequences.max() == 0:
- this_peer_finished = True
-
- # stop if we exceed the maximum length
- # thkim : backward compatibility bool vs list[bool]
- is_stop = stopping_criteria(input_ids, None)
- if isinstance(is_stop, torch.BoolTensor):
- is_stop = torch.all(is_stop)
- if is_stop:
- this_peer_finished = True
-
- if this_peer_finished:
- break
-
- if streamer is not None:
- streamer.end()
-
- if return_dict_in_generate:
- return SampleDecoderOnlyOutput(
- sequences=input_ids,
- scores=scores,
- logits=raw_logits,
- )
- else:
- return input_ids
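With this hunk the hand-rolled `_greedy_search`/`_sample` loops are removed from the GPT-2 model file; the class now inherits its generation loop from `RBLNGenerationMixin` in the new `optimum/rbln/transformers/generation/utils.py` (+399 lines, listed above). The sketch below shows only the general mixin pattern with hypothetical names; it is not the contents of that module:

```python
# Hypothetical illustration of the mixin composition pattern: a shared decode loop
# factored out so several model classes can inherit it. Names are made up.
import torch


class GreedyLoopMixin:
    """Shared token-by-token decode loop, usable by any model that implements
    prepare_inputs_for_generation() and a logits-returning forward()."""

    def greedy_generate(self, input_ids: torch.Tensor, max_new_tokens: int, eos_token_id: int) -> torch.Tensor:
        for _ in range(max_new_tokens):
            model_inputs = self.prepare_inputs_for_generation(input_ids)
            logits = self(**model_inputs).logits
            next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            if (next_tokens == eos_token_id).all():                       # every sequence finished
                break
        return input_ids
```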
optimum/rbln/transformers/models/llama/llama_architecture.py
@@ -36,7 +36,6 @@ from transformers.models.llama.modeling_llama import (
  LlamaForCausalLM,
  LlamaModel,
  LlamaRotaryEmbedding,
- repeat_kv,
  )
 
 
@@ -149,26 +148,41 @@ class _LlamaAttention(LlamaAttention):
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+ # change to remove repeat
+ key_states = key_states.unsqueeze(2)
+ value_states = value_states.unsqueeze(2)
+ query_states = query_states.view(
+ bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
+ )
+
  if past_key_value is not None:
  cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
+ # change to remove repeat
+ # key_states = repeat_kv(key_states, self.num_key_value_groups)
+ # value_states = repeat_kv(value_states, self.num_key_value_groups)
 
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+ # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
- f" {attn_weights.size()}"
- )
+ attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+ # change to remove repeat
+ # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ # raise ValueError(
+ # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ # f" {attn_weights.size()}"
+ # )
 
  if attention_mask is not None:
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  raise ValueError(
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
  )
+ else:
+ # change to remove repeat
+ attention_mask = attention_mask.unsqueeze(2)
+
  attn_weights = attn_weights + attention_mask
 
  # upcast attention to fp32
@@ -176,6 +190,9 @@ class _LlamaAttention(LlamaAttention):
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  attn_output = torch.matmul(attn_weights, value_states)
 
+ # change to remove repeat
+ attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  raise ValueError(
  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
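The attention hunks above drop `repeat_kv` and instead broadcast the query over a per-KV-head group axis. A self-contained check that the broadcasted matmul matches the repeat_kv-style formulation (tensor shapes are arbitrary example values, not the compiled model's configuration):

```python
import torch

bsz, num_heads, num_kv_heads, q_len, head_dim = 2, 8, 2, 4, 16
n_rep = num_heads // num_kv_heads

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_kv_heads, q_len, head_dim)

# Reference path: repeat each KV head n_rep times, then a plain 4-D matmul.
k_rep = k.repeat_interleave(n_rep, dim=1)               # [bsz, num_heads, q_len, head_dim]
ref = torch.matmul(q, k_rep.transpose(2, 3))            # [bsz, num_heads, q_len, q_len]

# Broadcast path (what the patched attention does): add a group axis instead of copying KV.
q5 = q.view(bsz, num_kv_heads, n_rep, q_len, head_dim)  # group query heads per KV head
k5 = k.unsqueeze(2)                                     # [bsz, num_kv_heads, 1, q_len, head_dim]
out = torch.matmul(q5, k5.transpose(3, 4))              # broadcasts over the n_rep axis
out = out.reshape(bsz, num_heads, q_len, q_len)

torch.testing.assert_close(ref, out)                    # same scores, no repeated KV tensors
```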
@@ -516,17 +533,32 @@ class RebelDynamicCache(DynamicCache):
  if len(self.key_cache) <= layer_idx:
  self.key_cache.append(key_states)
  self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
  else:
- self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
- key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+ # change to remove repeat
+ # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+ # key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+ # )
+ # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+ # value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+ # )
+ updated_key = (
+ self.key_cache[layer_idx]
+ .unsqueeze(2)
+ .slice_scatter(
+ key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
+ )
  )
- self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
- value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+ updated_value = (
+ self.value_cache[layer_idx]
+ .unsqueeze(2)
+ .slice_scatter(
+ value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
+ )
  )
- # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ self.key_cache[layer_idx] = updated_key.squeeze(2)
+ self.value_cache[layer_idx] = updated_value.squeeze(2)
+ return updated_key, updated_value
 
  def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  """Returns the sequence length of the cached states. A layer index can be optionally passed."""