optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
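Note on the restructuring above: the former monolithic modeling_base.py is split into a slimmer modeling_base.py plus a new modeling.py (and a new modeling_diffusers.py). As the seq2seq hunk below shows, `RBLNModel` now lives in `optimum/rbln/modeling.py`. A hedged sketch of what that means for code that imported it by its internal module path; top-level re-exports from `optimum.rbln` are assumed to keep working, which is not visible in this file list:

    # optimum-rbln 0.1.12 (old internal location):
    #   from optimum.rbln.modeling_base import RBLNModel
    # optimum-rbln 0.1.15 (new internal location, per the seq2seq hunk below):
    from optimum.rbln.modeling import RBLNModel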
optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import PhiForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .phi_architecture import PhiWrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Phi Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
@@ -50,20 +40,4 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return PhiWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(PhiForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(
-            inspect.signature(val).parameters
-        ):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = PhiWrapper
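The modeling_phi.py hunk above replaces the per-model `wrap_model_if_needed` override and the `__getattr__` redirection into `PhiForCausalLM` with a single declarative `_decoder_wrapper_cls` attribute. The sketch below only illustrates how a shared base class could consume such an attribute; the method name, signature, and `.eval()` call are assumptions carried over from the removed code, not the package's actual base-class API:

    from typing import Type

    import torch.nn as nn


    class ExampleDecoderOnlyBase:
        # Subclasses declare which torch.nn.Module wrapper to build
        # (e.g. PhiWrapper, QWEN2Wrapper) instead of overriding a method.
        _decoder_wrapper_cls: Type[nn.Module]

        @classmethod
        def wrap_model_if_needed(cls, model: nn.Module, max_seq_len: int) -> nn.Module:
            # Hypothetical consumer: instantiate the declared wrapper and
            # switch it to eval mode before compilation.
            return cls._decoder_wrapper_cls(model, max_seq_len).eval()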
optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -21,386 +21,102 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import math
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
-import torch.nn as nn
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from transformers import PhiForCausalLM
 
-from ...cache_utils import RebelDynamicCache
-from ..decoderonly import (
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    DynamicNTKScalingRotaryEmbedding,
-    LinearScalingRotaryEmbedding,
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-    slice_and_unsqueeze_cos_sin,
+    apply_rotary_pos_emb_partial,
 )
 
 
-class PhiWrapper(DecoderOnlyWrapper):
-    def _init_rope(self):
-        if self.rope_scaling is None:
-            rotary_emb = RotaryEmbedding(
-                int(self.config.partial_rotary_factor * self.head_dim),
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
-            )
-        else:
-            scaling_type = self.rope_scaling["type"]
-            scaling_factor = self.rope_scaling["factor"]
-            if scaling_type == "linear":
-                rotary_emb = LinearScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            elif scaling_type == "dynamic":
-                rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    int(self.config.partial_rotary_factor * self.head_dim),
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.config.rope_theta,
-                    max_seq_len=self.max_seq_len,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
-        return rotary_emb
-
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": PhiModel.forward,
-                "model": PhiDecoderLayer.forward,
-                "decoder_layer": PhiAttention.forward,
-            }
-        )
-        return forward_dict
-
+if TYPE_CHECKING:
+    from transformers import PhiForCausalLM
 
-class PhiAttention:
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[int] = None,
-        output_attentions: bool = False,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
 
+class PhiWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = PhiAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = PhiLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = PhiModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class PhiAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.dense
+        self.qk_layernorm = self._original_mod.qk_layernorm
+        self.rotary_ndims = self._original_mod.rotary_ndims
+        self.num_key_value_heads = self.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            query_states = self.q_layernorm(query_states)
-            key_states = self.k_layernorm(key_states)
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        # Partial rotary embedding
-        query_rot, query_pass = (
-            query_states[..., : rotary_pos_emb.dim],
-            query_states[..., rotary_pos_emb.dim :],
-        )
-        key_rot, key_pass = (
-            key_states[..., : rotary_pos_emb.dim],
-            key_states[..., rotary_pos_emb.dim :],
-        )
-
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-        # [batch_size, seq_length, num_heads, head_dim]
-        query_states = torch.cat((query_rot, query_pass), dim=-1)
-        key_states = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query_state = query_states[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key_state = key_states[b].unsqueeze(0)
-                value_state = value_states[b].unsqueeze(0)
-
-                # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-                key_state = key_state.unsqueeze(2)
-                value_state = value_state.unsqueeze(2)
-                attn_mask = attn_mask.unsqueeze(2)
-
-                query_state = query_state.view(
-                    1,
-                    self.num_key_value_heads,
-                    self.num_heads // self.num_key_value_heads,
-                    q_len,
-                    self.head_dim,
-                )
-
-                key_state, value_state = past_key_value.update(
-                    key_state,
-                    value_state,
-                    self.layer_idx,
-                    b,
-                )
-
-                # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-                attn_weights = torch.matmul(
-                    query_state.to(torch.float32),
-                    key_state.to(torch.float32).transpose(3, 4),
-                ) / math.sqrt(self.head_dim)
-                attn_weights = attn_weights + attn_mask
-
-                # upcast attention to fp32
-                attn_weights = nn.functional.softmax(
-                    attn_weights, dim=-1, dtype=torch.float32
-                ).to(query_states.dtype)
-                attn_weights = nn.functional.dropout(
-                    attn_weights, p=self.attention_dropout, training=self.training
-                )
-                attn_output = torch.matmul(attn_weights, value_state)
+            query_states = self._original_mod.q_layernorm(query_states)
+            key_states = self._original_mod.k_layernorm(key_states)
 
-                # reshape for removing repeat_kv
-                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-                attn_output = attn_output.transpose(1, 2).contiguous()
-                attn_output = attn_output.reshape(
-                    1, q_len, self.num_heads * self.head_dim
-                )
+        return query_states, key_states, value_states
 
-                all_key_states.append(key_state)
-                all_value_states.append(value_state)
-                all_attn_output.append(attn_output)
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)
 
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
 
-            # reshape for removing repeat_kv
-            key_states = key_states.unsqueeze(2)
-            value_states = value_states.unsqueeze(2)
-            attention_mask = attention_mask.unsqueeze(2)
-            query_states = query_states.view(
-                1,
-                self.num_key_value_heads,
-                self.num_heads // self.num_key_value_heads,
-                q_len,
-                self.head_dim,
-            )
+class PhiLayer(DecoderOnlyLayer):
+    def get_post_attention_layernorm(self):
+        raise NotImplementedError
 
-            key_states, value_states = past_key_value.update(
-                key_states,
-                value_states,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-            attn_weights = torch.matmul(
-                query_states.to(torch.float32),
-                key_states.to(torch.float32).transpose(3, 4),
-            ) / math.sqrt(self.head_dim)
-            attn_weights = attn_weights + attention_mask
-
-            # upcast attention to fp32
-            attn_weights = torch.nn.functional.softmax(
-                attn_weights, dim=-1, dtype=torch.float32
-            ).to(value_states.dtype)
-            attn_weights = torch.nn.functional.dropout(
-                attn_weights, p=self.attention_dropout, training=self.training
-            )
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            # reshape for removing repeat_kv
-            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-            attn_output = attn_output.transpose(1, 2).contiguous()
-            attn_output = attn_output.reshape(
-                bsz, q_len, self.num_heads * self.head_dim
-            )
-
-        attn_output = self.dense(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, key_states, value_states
-
-
-class PhiDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-        layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        rotary_pos_emb=None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`):
-                input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
-                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
+    ):
         residual = hidden_states
 
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-        # Self Attention
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
-            "decoder_layer"
-        ](
-            self.self_attn,
+        attn_outputs, present_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            current_steps=current_steps,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
            cos=cos,
            sin=sin,
-            rotary_pos_emb=rotary_pos_emb,
-            **kwargs,
        )
-        past_key_value.assign(key_states, value_states, layer_idx)
 
-        attn_outputs = self.resid_dropout(attn_outputs)
+        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
 
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (past_key_value,)
-
-        return outputs
-
 
-class PhiModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> BaseModelOutputWithPast:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        hidden_states = inputs_embeds
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                rotary_pos_emb=rotary_pos_emb,
-                forward_dict=forward_dict,
-            )
+        return hidden_states, present_key_values
 
-            hidden_states = layer_outputs[0]
 
-            updated_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.final_layernorm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class PhiModel(DecoderOnlyModel):
+    def get_last_layernorm(self):
+        return self._original_mod.final_layernorm
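The phi_architecture.py hunk above drops the hand-rolled "partial rotary embedding" block and delegates it to `apply_rotary_pos_emb_partial`, rotating only the first `rotary_ndims` channels of each head (Phi's `partial_rotary_factor`) and passing the remaining channels through unchanged. A self-contained sketch of that idea using the usual rotate-half convention; the helper names are illustrative, not the package's implementation:

    import torch


    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Standard RoPE helper: (x1, x2) -> (-x2, x1) on the last dimension.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)


    def apply_rope_partial(q, k, cos, sin, ndim):
        # Rotate only the first `ndim` channels; leave the rest untouched.
        # cos/sin are assumed broadcastable against the rotated slice.
        q_rot, q_pass = q[..., :ndim], q[..., ndim:]
        k_rot, k_pass = k[..., :ndim], k[..., ndim:]
        q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
        k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
        return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)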
optimum/rbln/transformers/models/qwen2/modeling_qwen2.py
@@ -21,28 +21,18 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import Qwen2ForCausalLM
-
-from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ....utils import logging
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .qwen2_architecture import QWEN2Wrapper
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-    from ....modeling_config import RBLNConfig
-
-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)
 
 
 class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Llama Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return QWEN2Wrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(Qwen2ForCausalLM, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-
-        return val
+    _decoder_wrapper_cls = QWEN2Wrapper
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -31,7 +31,7 @@ import torch # noqa: F401
 from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 
@@ -346,51 +346,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None, # vllm return current attention_mask length
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
-        # encoder
-        if batch_idx is not None:
-            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["enc_max_seq_len"],
-                dtype=torch.float32,
-            )
-            dec_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["dec_max_seq_len"],
-                dtype=torch.float32,
-            )
-            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
-
-        return Seq2SeqLMOutput(logits=logits)
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
optimum/rbln/transformers/models/t5/__init__.py
@@ -21,5 +21,5 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .modeling_t5 import RBLNT5ForConditionalGeneration
+from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
 from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
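The hunk above newly exports `RBLNT5EncoderModel` (implemented in modeling_t5.py, +157 -6 in the file list). A hedged usage sketch, assuming the class is re-exported from the package root and follows the same `from_pretrained(..., export=True)` compile-on-load convention as the other optimum-rbln model classes; the model id and arguments are illustrative:

    from optimum.rbln import RBLNT5EncoderModel  # assumes a root-level re-export

    # Compile a stock T5 encoder for RBLN devices (illustrative arguments).
    encoder = RBLNT5EncoderModel.from_pretrained(
        "google-t5/t5-base",  # any T5 checkpoint
        export=True,          # convert and compile from the original weights
    )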