optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in those public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -20,62 +20,77 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-import torch
+
+from typing import TYPE_CHECKING
+
+import torch.nn as nn

 from ....utils import logging
-from ...models.decoderonly import (
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyDecoderLayer,
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
-    RotaryEmbedding,
 )


+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as ExaoneForCausalLM
+
 logger = logging.get_logger(__name__)


 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""

-    def __init__(self, model, max_seq_len, kvcache_partition_len=None):
-        super(DecoderOnlyWrapper, self).__init__()
-        self.config = model.config
-        self.model = self.convert_attribute_name(model.transformer)
-        self.lm_head = model.lm_head
-        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-        if kvcache_partition_len is not None:
-            # WORKAROUND : for passing partition length as a value to the rbln compiler.
-            # What is actually used is the shape of this tensor.
-            self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
-            self.attn_implementation = "flash_attn_rbln"
-            logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
-        else:
-            self.kvcache_partition_size = None
-            self.attn_implementation = "eager"
-
-    @staticmethod
-    def convert_attribute_name(model):
-        model.embed_tokens = model.wte
-        model.norm = model.ln_f
-        model.layers = model.h
-
-        for layer in model.layers:
-            layer.input_layernorm = layer.ln_1
-            layer.self_attn = layer.attn.attention
-            layer.post_attention_layernorm = layer.ln_2
-            layer.self_attn.o_proj = layer.self_attn.out_proj
-
-        return model
-
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": DecoderOnlyModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DecoderOnlyAttention.forward,
-            }
-        )
-        return forward_dict
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            if self.attn_impl == "eager":
+                new_self_attn = ExaoneAttention(layer.attn.attention)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = ExaoneFlashAttention(
+                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = ExaoneLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class ExaoneModel(DecoderOnlyModel):
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
+
+
+class ExaoneLayer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
+
+
+class ExaoneAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+
+
+class ExaoneFlashAttention(DecoderOnlyFlashAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
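The hunk above replaces the old attribute-renaming and `forward_dict` plumbing with a small adapter layer: `convert_to_rbln_causal_lm` rebuilds the Hugging Face module tree out of generic `DecoderOnly*` wrappers, and the model-specific subclasses only override the accessors whose names differ in the original checkpoint. As a rough illustration of that pattern (the `MyModelWrapper`/`MyModelAttention` names and the `causal_lm.model.layers` module path below are hypothetical, not code from this release):

```python
# Hypothetical sketch of the adapter pattern introduced above; the wrapped
# model's class and attribute names are assumptions, not code from 0.2.0.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    DecoderOnlyAttention,
    DecoderOnlyForCausalLM,
    DecoderOnlyLayer,
    DecoderOnlyModel,
    DecoderOnlyWrapper,
)


class MyModelWrapper(DecoderOnlyWrapper):
    def convert_to_rbln_causal_lm(self, causal_lm):
        # Wrap each original decoder layer with the generic RBLN layer classes.
        new_layers = []
        for layer in causal_lm.model.layers:  # assumed module path
            new_self_attn = MyModelAttention(layer.self_attn)
            new_layers.append(DecoderOnlyLayer(layer, new_self_attn))
        new_model = DecoderOnlyModel(causal_lm.model, new_layers)
        return DecoderOnlyForCausalLM(causal_lm, new_model)


class MyModelAttention(DecoderOnlyAttention):
    def __post_init__(self):
        # Map the original projection modules onto the names the generic
        # RBLN attention implementation expects.
        self.q_proj = self._original_mod.q_proj
        self.k_proj = self._original_mod.k_proj
        self.v_proj = self._original_mod.v_proj
        self.o_proj = self._original_mod.o_proj
```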
@@ -21,10 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .exaone_architecture import ExaoneForCausalLMWrapper
-from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


 logger = logging.get_logger(__name__)
@@ -45,7 +47,7 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = ExaoneForCausalLMWrapper
-    _original_cls = ExaoneForCausalLM
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
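With `_hf_class = AutoModelForCausalLM`, the vendored `hf_hub_cached` copy of the EXAONE modeling code (deleted further down in this diff) is no longer needed; the original weights are now resolved through `transformers`' auto classes, with the overridden `from_pretrained` shown above presumably handling the model-specific loading details. A hypothetical usage sketch, assuming the class is exported at the package root and that the usual optimum-style `export=True` / `rbln_*` keyword arguments apply (neither is confirmed by this hunk):

```python
# Hypothetical usage sketch; the checkpoint name and rbln_* kwargs are
# assumptions about the optimum-rbln API, not taken from this diff.
from optimum.rbln import RBLNExaoneForCausalLM

model = RBLNExaoneForCausalLM.from_pretrained(
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",  # resolved via AutoModelForCausalLM
    export=True,                             # compile the model for the RBLN NPU
    rbln_max_seq_len=4096,                   # assumed compile-time option
)
model.save_pretrained("exaone-rbln")
```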
@@ -21,113 +21,42 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
-
-from ...models.decoderonly import (
-    DecoderOnlyDecoderLayer,
+from typing import TYPE_CHECKING
+
+from ...models.decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    slice_and_unsqueeze_cos_sin,
 )
-from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES
-
-
-class GemmaWrapper(DecoderOnlyWrapper):
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": GemmaModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
-            }
-        )
-        return forward_dict
-
-
-class GemmaModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        inputs_embeds = self.embed_tokens(input_ids)
-        hidden_states = inputs_embeds

-        ##### GEMMA change from llama#####
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)

-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
+if TYPE_CHECKING:
+    from transformers import GemmaForCausalLM

-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)

-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                cache_pos_for_partitions=cache_pos_for_partitions,
-                kvcache_partition_size=kvcache_partition_size,
-                forward_dict=forward_dict,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            updated_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class GemmaWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GemmaModel(causal_lm.model, new_layers, partition_len=self.kvcache_partition_len)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class GemmaModel(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self._original_mod.config.hidden_size**0.5
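The copied-and-edited Gemma forward pass removed above collapses into a single override: the only Gemma-specific behavior the old code carried was scaling the token embeddings by sqrt(hidden_size) before the decoder stack, which the new `hidden_multiplier` property expresses declaratively. A minimal sketch of how such a hook would be consumed by a shared decoder-only forward (the toy class below is illustrative, not the actual `DecoderOnlyModel` implementation):

```python
# Illustrative only: a toy base class showing where a hidden_multiplier hook
# plugs in; the real DecoderOnlyModel in 0.2.0 is more involved.
import torch
import torch.nn as nn


class ToyDecoderOnlyModel(nn.Module):
    def __init__(self, embed_tokens: nn.Embedding):
        super().__init__()
        self.embed_tokens = embed_tokens

    @property
    def hidden_multiplier(self) -> float:
        return 1.0  # default: no scaling

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # GemmaModel overrides hidden_multiplier with hidden_size**0.5, matching
        # the removed `hidden_states * (self.config.hidden_size**0.5)` line above.
        return self.embed_tokens(input_ids) * self.hidden_multiplier
```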
@@ -21,262 +21,74 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import Dict, Optional, Tuple, Union
+import math
+from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast

-from ...cache_utils import RebelDynamicCache_4D
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)


-class GPT2LMHeadModelWrapper(torch.nn.Module):
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-        self.forward_dict = self.get_forward_dict()
+if TYPE_CHECKING:
+    from transformers import GPT2LMHeadModel

-    def get_forward_dict(self):
-        forward_dict = {
-            "wrapper": _GPT2Model.forward,
-            "model": _GPT2Block.forward,
-            "decoder_layer": _GPT2Attention.forward,
-        }
-        return forward_dict

-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        cache_position,
-        batch_position,
-        query_idx,
-        *past_key_values,
-    ):
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
+class GPT2Wrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = GPT2Attention(layer.attn)
+            new_layer = GPT2Layer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm

-        # Formatting list of past_kv to DynamicCache class.
-        past_key_value = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.n_layer,
-            *past_key_values,
-        )

-        outputs = self.forward_dict["wrapper"](
-            self.model,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_value=past_key_value,
-            batch_ids=rbln_batch_position,
-            forward_dict=self.forward_dict,
-            # rotary_emb differenct from_llama
-        )
+class GPT2Model(DecoderOnlyModel):
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f

-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-        logits = self.lm_head(hidden_states)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte

-        output = (logits,) + outputs[1:]
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe

-        return output, batch_position + query_idx

+class GPT2Layer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1

-class _GPT2Model:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-    ) -> BaseModelOutputWithPast:
-        b_size, q_len = input_ids.shape
-        inputs_embeds = self.wte(input_ids)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2

-        if position_ids.shape[0] > 1:
-            position_embeds = []
-            for b_idx in range(b_size):
-                position_embed = self.wpe(position_ids[b_idx])
-                # position_embed = position_embed.dtype(inputs_embeds.dtype)
-                position_embeds.append(position_embed)

-            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-        else:
-            position_embeds = self.wpe(position_ids)
+class GPT2Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size

-        hidden_states = inputs_embeds + position_embeds
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states

-        # GPT2Attention mask.
-        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+    def get_attn_scale(self):
+        scale = 1.0
+        if self._original_mod.scale_attn_weights:
+            scale /= math.sqrt(self.head_dim)

-        for layer_idx, block in enumerate(self.h):
-            hidden_states, updated_cache = forward_dict["model"](
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_value,
-                position_ids=position_ids,
-                batch_ids=batch_ids,
-                forward_dict=forward_dict,
-            )
+        if self._original_mod.scale_attn_by_inverse_layer_idx:
+            scale /= 1 + self.layer_idx

-        hidden_states = self.ln_f(hidden_states)
-        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
-
-
-class _GPT2Block:
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
-
-        hidden_states, k, v = forward_dict["decoder_layer"](
-            self.attn,
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            batch_index=batch_ids,
-        )
-        past_key_value.assign(k, v, layer_idx)
-
-        # residual connection
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _GPT2Attention:
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
-
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        # -------------------
-        # Below are deleted since "where" op does not supported on RBLN graph.
-        # -------------------
-        # if not self.is_cross_attention:
-        #     # if only "normal" attention layer implements causal mask
-        #     query_length, key_length = query.size(-2), key.size(-2)
-        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-        #     mask_value = torch.finfo(attn_weights.dtype).min
-        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
-        # Apply the attention mask
-        attn_weights.view(
-            -1,
-        )
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        bsz, q_len, _ = hidden_states.size()
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
-        keys = self._split_heads(key, self.num_heads, self.head_dim)
-        values = self._split_heads(value, self.num_heads, self.head_dim)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_keys = []
-            all_values = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_keys.append(key)
-                all_values.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_keys, dim=0)
-            values = torch.cat(all_values, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        # Prefill
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-
-        return attn_output, keys, values
+        return scale
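For GPT-2, the fused `c_attn` projection and the two attention-scaling config flags become two small overrides: `projection` splits the fused QKV output along the last dimension, and `get_attn_scale` folds `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` into a single multiplier. A standalone worked example of that scale computation follows (the GPT-2 base shapes, hidden_size 768 over 12 heads giving head_dim 64, are assumed for the numbers):

```python
# Standalone re-derivation of get_attn_scale above; the config values used in
# the asserts are stock GPT-2 defaults, assumed here purely for the arithmetic.
import math


def gpt2_attn_scale(
    head_dim: int,
    layer_idx: int,
    scale_attn_weights: bool = True,
    scale_attn_by_inverse_layer_idx: bool = False,
) -> float:
    scale = 1.0
    if scale_attn_weights:
        scale /= math.sqrt(head_dim)  # 1/sqrt(64) = 0.125 for GPT-2 base
    if scale_attn_by_inverse_layer_idx:
        scale /= 1 + layer_idx  # optional layer-wise damping
    return scale


assert gpt2_attn_scale(head_dim=64, layer_idx=0) == 0.125
assert gpt2_attn_scale(head_dim=64, layer_idx=3, scale_attn_by_inverse_layer_idx=True) == 0.125 / 4
```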
@@ -23,7 +23,7 @@

 from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import GPT2LMHeadModelWrapper
+from .gpt2_architecture import GPT2Wrapper


 logger = logging.get_logger(__name__)
@@ -43,4 +43,5 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

     """

-    _decoder_wrapper_cls = GPT2LMHeadModelWrapper
+    _decoder_wrapper_cls = GPT2Wrapper
+    _use_rotary_emb = False
@@ -21,7 +21,6 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
 from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper
