optimum_rbln-0.1.12-py3-none-any.whl → optimum_rbln-0.1.15-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -21,103 +21,42 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING
 
-import torch
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
-
-from ...models.decoderonly import (
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyDecoderLayer,
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    slice_and_unsqueeze_cos_sin,
 )
 
 
-class GemmaWrapper(DecoderOnlyWrapper):
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": GemmaModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DecoderOnlyAttention.forward,
-            }
-        )
-        return forward_dict
-
-
-class GemmaModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        # embed positions
-        inputs_embeds = self.embed_tokens(input_ids)
-        hidden_states = inputs_embeds
-
-        ##### GEMMA change from llama#####
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
-
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+if TYPE_CHECKING:
+    from transformers import GemmaForCausalLM
 
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
 
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                forward_dict=forward_dict,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            updated_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class GemmaWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = DecoderOnlyAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = DecoderOnlyFlashAttention(
+                    layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = DecoderOnlyLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GemmaModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class GemmaModel(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self._original_mod.config.hidden_size**0.5
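The new gemma_architecture no longer routes forward passes through a `forward_dict` of patched methods; it composes adapter objects around the original Hugging Face modules and overrides only the model-specific hook (here, the embedding scale exposed as `hidden_multiplier`). Below is a minimal, self-contained sketch of that composition pattern; `ToyDecoder`, `DecoderAdapter`, and `GemmaLikeAdapter` are illustrative stand-ins, not optimum-rbln classes.

```python
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    """Stand-in for a Hugging Face decoder body (hypothetical, for illustration only)."""

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.norm(self.embed_tokens(input_ids))


class DecoderAdapter(nn.Module):
    """Hold the original module and delegate to it, overriding only model-specific hooks."""

    def __init__(self, original_mod: nn.Module):
        super().__init__()
        self._original_mod = original_mod

    @property
    def hidden_multiplier(self) -> float:
        return 1.0  # most models do not rescale their embeddings

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self._original_mod.embed_tokens(input_ids) * self.hidden_multiplier
        return self._original_mod.norm(hidden)


class GemmaLikeAdapter(DecoderAdapter):
    @property
    def hidden_multiplier(self) -> float:
        # Gemma scales token embeddings by sqrt(hidden_size) before the decoder stack.
        return self._original_mod.hidden_size ** 0.5


ids = torch.randint(0, 32, (1, 4))
print(GemmaLikeAdapter(ToyDecoder())(ids).shape)  # torch.Size([1, 4, 16])
```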
optimum/rbln/transformers/models/gemma/modeling_gemma.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GemmaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma_architecture import GemmaWrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GemmaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GemmaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = GemmaWrapper
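The public entry point is unchanged by this refactor: the model is still compiled and run through `from_pretrained`. A usage sketch follows; the `rbln_*` keyword names mirror the convention visible in this diff (the old code read `rbln_config.model_cfg["max_seq_len"]`), but the exact options accepted by 0.1.15 should be checked against the release documentation, and the checkpoint name is only an example.

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNGemmaForCausalLM

# Compile the Hugging Face checkpoint for RBLN NPUs, then run ordinary generate().
model = RBLNGemmaForCausalLM.from_pretrained(
    "google/gemma-2b",
    export=True,            # trace and compile with the RBLN compiler
    rbln_max_seq_len=8192,  # static sequence length fixed at compile time
    rbln_batch_size=1,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```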
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -21,262 +21,74 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from ...cache_utils import RebelDynamicCache_4D
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)
 
 
-class GPT2LMHeadModelWrapper(torch.nn.Module):
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-        self.forward_dict = self.get_forward_dict()
+if TYPE_CHECKING:
+    from transformers import GPT2LMHeadModel
 
-    def get_forward_dict(self):
-        forward_dict = {
-            "wrapper": _GPT2Model.forward,
-            "model": _GPT2Block.forward,
-            "decoder_layer": _GPT2Attention.forward,
-        }
-        return forward_dict
 
-    def forward(
-        self,
-        input_ids,
-        attention_mask,
-        cache_position,
-        batch_position,
-        query_idx,
-        *past_key_values,
-    ):
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
+class GPT2Wrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = GPT2Attention(layer.attn)
+            new_layer = GPT2Layer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
 
-        # Formatting list of past_kv to DynamicCache class.
-        past_key_value = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.n_layer,
-            *past_key_values,
-        )
-
-        outputs = self.forward_dict["wrapper"](
-            self.model,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_value=past_key_value,
-            batch_ids=rbln_batch_position,
-            forward_dict=self.forward_dict,
-            # rotary_emb differenct from_llama
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-        logits = self.lm_head(hidden_states)
-
-        output = (logits,) + outputs[1:]
 
-        return output, batch_position + query_idx
+class GPT2Model(DecoderOnlyModel):
+    mask_fmin = torch.finfo(torch.float32).min
 
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
 
-class _GPT2Model:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-    ) -> BaseModelOutputWithPast:
-        b_size, q_len = input_ids.shape
-        inputs_embeds = self.wte(input_ids)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
 
-        if position_ids.shape[0] > 1:
-            position_embeds = []
-            for b_idx in range(b_size):
-                position_embed = self.wpe(position_ids[b_idx])
-                # position_embed = position_embed.dtype(inputs_embeds.dtype)
-                position_embeds.append(position_embed)
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe
 
-            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-        else:
-            position_embeds = self.wpe(position_ids)
 
-        hidden_states = inputs_embeds + position_embeds
+class GPT2Layer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
 
-        # GPT2Attention mask.
-        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
 
-        for layer_idx, block in enumerate(self.h):
-            hidden_states, updated_cache = forward_dict["model"](
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_value,
-                position_ids=position_ids,
-                batch_ids=batch_ids,
-                forward_dict=forward_dict,
-            )
-
-        hidden_states = self.ln_f(hidden_states)
-        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
 
+class GPT2Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
 
-class _GPT2Block:
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
 
-        hidden_states, k, v = forward_dict["decoder_layer"](
-            self.attn,
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            batch_index=batch_ids,
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        # residual connection
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _GPT2Attention:
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
-
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        # -------------------
-        # Below are deleted since "where" op does not supported on RBLN graph.
-        # -------------------
-        # if not self.is_cross_attention:
-        #     # if only "normal" attention layer implements causal mask
-        #     query_length, key_length = query.size(-2), key.size(-2)
-        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-        #     mask_value = torch.finfo(attn_weights.dtype).min
-        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
-        # Apply the attention mask
-        attn_weights.view(
-            -1,
-        )
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        bsz, q_len, _ = hidden_states.size()
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
-        keys = self._split_heads(key, self.num_heads, self.head_dim)
-        values = self._split_heads(value, self.num_heads, self.head_dim)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_keys = []
-            all_values = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_keys.append(key)
-                all_values.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_keys, dim=0)
-            values = torch.cat(all_values, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        # Prefill
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-
-        return attn_output, keys, values
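The only attention-specific override GPT-2 keeps is `projection`: GPT-2 computes query, key, and value with one fused `c_attn` projection that is split afterwards, whereas Llama-style models use three separate projections. A small self-contained torch sketch of the difference (plain `nn.Linear` stands in for GPT-2's actual `Conv1D` module; shapes are toy values):

```python
import torch
import torch.nn as nn

hidden_size, seq_len = 8, 4
hidden_states = torch.randn(1, seq_len, hidden_size)

# GPT-2 style: one fused projection producing [q | k | v] along the last dim,
# split afterwards with split_size == hidden_size.
c_attn = nn.Linear(hidden_size, 3 * hidden_size)
q, k, v = c_attn(hidden_states).split(hidden_size, dim=2)

# Llama/Gemma style: three separate projections.
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)
q2, k2, v2 = q_proj(hidden_states), k_proj(hidden_states), v_proj(hidden_states)

print(q.shape, k.shape, v.shape)  # each torch.Size([1, 4, 8])
```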
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -21,20 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import GPT2LMHeadModel
-
-from ....modeling_config import RBLNConfig
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import GPT2LMHeadModelWrapper
+from .gpt2_architecture import GPT2Wrapper  # GPT2LMHeadModelWrapper
 
 
-logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+logger = logging.get_logger(__name__)
 
 
 class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -42,7 +34,7 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -51,22 +43,5 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
 
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-GPT2.
-
-        Returns:
-            Any: GPT2's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(GPT2LMHeadModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = GPT2Wrapper
+    _use_rotary_emb = False
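Across these modeling files the per-model boilerplate (`wrap_model_if_needed` plus a `__getattr__` shim) collapses into two class attributes that the shared base class reads. A minimal sketch of how such declarative class attributes can drive a common export path; `BaseForCausalLM`, `FakeWrapper`, and `GPT2LikeForCausalLM` are illustrative names, not the actual RBLN base-class API.

```python
from typing import Type


class FakeWrapper:
    """Placeholder for a compile-time model wrapper."""

    def __init__(self, model: object, use_rotary_emb: bool):
        self.model = model
        self.use_rotary_emb = use_rotary_emb


class BaseForCausalLM:
    # Subclasses only declare which wrapper to use and whether rotary embeddings apply;
    # the shared wrap_model() path below reads these class attributes.
    _decoder_wrapper_cls: Type[FakeWrapper] = FakeWrapper
    _use_rotary_emb: bool = True

    @classmethod
    def wrap_model(cls, model: object) -> FakeWrapper:
        return cls._decoder_wrapper_cls(model, use_rotary_emb=cls._use_rotary_emb)


class GPT2LikeForCausalLM(BaseForCausalLM):
    _use_rotary_emb = False  # GPT-2 uses learned absolute position embeddings


print(BaseForCausalLM.wrap_model("llama-like").use_rotary_emb)   # True
print(GPT2LikeForCausalLM.wrap_model("gpt2-like").use_rotary_emb)  # False
```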
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import LlamaForCausalLM
-
+from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .llama_architecture import LlamaWrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return LlamaWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(LlamaForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = LlamaWrapper
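For reference, the `__getattr__` shim removed from each of these modeling files worked by looking the requested name up on the Hugging Face class and re-binding `self` when it found a method. A standalone sketch of that delegation pattern, using hypothetical `Upstream`/`Redirector` classes rather than the real model classes:

```python
import inspect
from typing import Any, Callable


class Upstream:
    """Stand-in for a Hugging Face model class whose methods we want to reuse."""

    def greet(self, name: str) -> str:
        return f"hello {name}"


class Redirector:
    """Forward unknown attribute lookups to Upstream, re-binding methods to this instance."""

    def __getattr__(self, __name: str) -> Any:
        def redirect(func):
            # Bind the unbound Upstream function to this Redirector instance.
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(Upstream, __name)
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val


print(Redirector().greet("rbln"))  # hello rbln
```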