optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -21,302 +21,101 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import math
- from typing import Dict, Optional, Tuple
+ from typing import TYPE_CHECKING, Optional, Tuple

  import torch
- import torch.nn as nn
- from transformers.modeling_outputs import (
- BaseModelOutputWithPast,
- )
+ from transformers import PhiForCausalLM

- from ...cache_utils import RebelDynamicCache
- from ..decoderonly import (
+ from ..decoderonly.decoderonly_architecture import (
+ DecoderOnlyAttention,
+ DecoderOnlyForCausalLM,
+ DecoderOnlyLayer,
+ DecoderOnlyModel,
  DecoderOnlyWrapper,
- apply_rotary_pos_emb,
- slice_and_unsqueeze_cos_sin,
+ apply_rotary_pos_emb_partial,
  )


- class PhiWrapper(DecoderOnlyWrapper):
- def get_forward_dict(self):
- forward_dict = {}
- forward_dict.update(
- {
- "wrapper": PhiModel.forward,
- "model": PhiDecoderLayer.forward,
- "decoder_layer": PhiAttention.forward,
- }
- )
- return forward_dict
-
-
- class PhiAttention:
- def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
- key_state = key_state.unsqueeze(2)
- value_state = value_state.unsqueeze(2)
- attn_mask = attn_mask.unsqueeze(2)
-
- query_state = query_state.view(
- 1,
- self.num_key_value_heads,
- self.num_heads // self.num_key_value_heads,
- -1,
- self.head_dim,
- )
-
- key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
-
- # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
- attn_weights = torch.matmul(
- query_state.to(torch.float32),
- key_state.to(torch.float32).transpose(3, 4),
- ) / math.sqrt(self.head_dim)
- attn_weights = attn_weights + attn_mask
-
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
- attn_output = torch.matmul(attn_weights, value_state)
-
- # reshape for removing repeat_kv
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+ if TYPE_CHECKING:
+ from transformers import PhiForCausalLM

- return attn_output, key_state, value_state
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[RebelDynamicCache] = None,
- batch_index: Optional[int] = None,
- output_attentions: bool = False,
- cos: Optional[torch.Tensor] = None,
- sin: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()

+ class PhiWrapper(DecoderOnlyWrapper):
+ def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+ new_layers = []
+ for layer in causal_lm.model.layers:
+ if self.attn_impl == "eager":
+ new_self_attn = PhiAttention(layer.self_attn)
+ elif self.attn_impl == "flash_attn":
+ raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
+ else:
+ raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+ new_layer = PhiLayer(layer, new_self_attn)
+ new_layers.append(new_layer)
+ new_model = PhiModel(causal_lm.model, new_layers)
+ new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+ return new_causal_lm
+
+
+ class PhiAttention(DecoderOnlyAttention):
+ def __post_init__(self):
+ self.q_proj = self._original_mod.q_proj
+ self.k_proj = self._original_mod.k_proj
+ self.v_proj = self._original_mod.v_proj
+ self.o_proj = self._original_mod.dense
+ self.qk_layernorm = self._original_mod.qk_layernorm
+ self.rotary_ndims = self._original_mod.rotary_ndims
+
+ def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  query_states = self.q_proj(hidden_states)
  key_states = self.k_proj(hidden_states)
  value_states = self.v_proj(hidden_states)

  if self.qk_layernorm:
- query_states = self.q_layernorm(query_states)
- key_states = self.k_layernorm(key_states)
-
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
- # Partial rotary embedding
- query_rot, query_pass = (
- query_states[..., : self.rotary_ndims],
- query_states[..., self.rotary_ndims :],
- )
- key_rot, key_pass = (
- key_states[..., : self.rotary_ndims],
- key_states[..., self.rotary_ndims :],
- )
+ query_states = self._original_mod.q_layernorm(query_states)
+ key_states = self._original_mod.k_layernorm(key_states)

- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+ return query_states, key_states, value_states

- # [batch_size, seq_length, num_heads, head_dim]
- query_states = torch.cat((query_rot, query_pass), dim=-1)
- key_states = torch.cat((key_rot, key_pass), dim=-1)
+ def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+ return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)

- # Decoder (bsz > 1)
- if bsz > 1:
- iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
- for b in range(bsz):
- attn_output, key_state, value_state = PhiAttention._attn(
- self,
- query_states[b].unsqueeze(0),
- key_states[b].unsqueeze(0),
- value_states[b].unsqueeze(0),
- attention_mask[b].unsqueeze(0),
- past_key_value,
- batch_idx=b,
- is_prefill=False,
- )
- iterate_results["key_states"].append(key_state)
- iterate_results["value_states"].append(value_state)
- iterate_results["attn_output"].append(attn_output)

- key_states = torch.cat(iterate_results["key_states"], dim=0)
- value_states = torch.cat(iterate_results["value_states"], dim=0)
- attn_output = torch.cat(iterate_results["attn_output"], dim=0)
- # Prefill & Decoder (bsz == 1)
- else:
- attn_output, key_states, value_states = PhiAttention._attn(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- past_key_value,
- batch_idx=batch_index,
- is_prefill=True,
- )
+ class PhiLayer(DecoderOnlyLayer):
+ def get_post_attention_layernorm(self):
+ raise NotImplementedError

- attn_output = self.dense(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, key_states, value_states
-
-
- class PhiDecoderLayer:
  def forward(
  self,
  hidden_states: torch.Tensor,
- layer_idx: int,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[RebelDynamicCache] = None,
- output_attentions: Optional[bool] = None,
- use_cache: Optional[bool] = None,
- batch_ids: Optional[torch.LongTensor] = None,
+ attention_mask: torch.Tensor,
+ seq_positions: torch.LongTensor,
+ batch_position: torch.Tensor,
+ past_key_values: Tuple[Tuple[torch.Tensor]],
  cos: Optional[torch.Tensor] = None,
  sin: Optional[torch.Tensor] = None,
- forward_dict: Optional[Dict[str, classmethod]] = None,
- **kwargs,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`):
- input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
- `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
-
+ ):
  residual = hidden_states

- hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.get_pre_attention_layernorm()(hidden_states)

- # Self Attention
- attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
- self.self_attn,
+ attn_outputs, present_key_values = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- batch_index=batch_ids,
- use_cache=use_cache,
+ seq_positions=seq_positions,
+ batch_position=batch_position,
+ past_key_values=past_key_values,
  cos=cos,
  sin=sin,
- **kwargs,
  )
- past_key_value.assign(key_states, value_states, layer_idx)

- attn_outputs = self.resid_dropout(attn_outputs)
+ feed_forward_hidden_states = self._original_mod.mlp(hidden_states)

- feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
  hidden_states = attn_outputs + feed_forward_hidden_states + residual
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- if use_cache:
- outputs += (past_key_value,)
-
- return outputs
-
-
- class PhiModel:
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[RebelDynamicCache] = None,
- batch_ids: Optional[torch.LongTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = True,
- output_attentions: Optional[bool] = False,
- output_hidden_states: Optional[bool] = False,
- cache_pos_for_partitions: Optional[torch.Tensor] = None,
- kvcache_partition_size: Optional[torch.Tensor] = None,
- forward_dict: Optional[Dict[str, classmethod]] = None,
- rotary_pos_emb=None,
- ) -> BaseModelOutputWithPast:
- # retrieve input_ids and inputs_embeds
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError(
- "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
- )
-
- # embed positions
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
-
- hidden_states = inputs_embeds
- attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
- # get cos,sin vector
- cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
- cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
- for layer_idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- layer_outputs = forward_dict["model"](
- decoder_layer,
- hidden_states,
- layer_idx,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- batch_ids=batch_ids,
- cos=cos,
- sin=sin,
- cache_pos_for_partitions=cache_pos_for_partitions,
- kvcache_partition_size=kvcache_partition_size,
- forward_dict=forward_dict,
- )

- hidden_states = layer_outputs[0]
+ return hidden_states, present_key_values

- updated_cache = layer_outputs[2 if output_attentions else 1]

- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- hidden_states = self.final_layernorm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
- next_cache = updated_cache.to_legacy_cache()
-
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
+ class PhiModel(DecoderOnlyModel):
+ def get_last_layernorm(self):
+ return self._original_mod.final_layernorm
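Note on the hunk above: the inline "Partial rotary embedding" code removed from PhiAttention is replaced by a call to apply_rotary_pos_emb_partial. A minimal sketch of what that helper presumably does, reconstructed from the deleted lines (only the first rotary_ndims channels are rotated, the rest pass through); the actual implementation lives in decoderonly_architecture.py and is not shown in this diff:

import torch

def rotate_half(x):
    # standard rotary helper: split the last dim in half, swap and negate
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_partial_sketch(q, k, cos, sin, ndim):
    # hypothetical stand-in for apply_rotary_pos_emb_partial, inferred from the removed code above
    q_rot, q_pass = q[..., :ndim], q[..., ndim:]
    k_rot, k_pass = k[..., :ndim], k[..., ndim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)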
@@ -21,7 +21,6 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

-
  from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper


@@ -26,13 +26,14 @@ import logging
  from abc import ABC
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

- import rebel # noqa: F401
- import torch # noqa: F401
+ import rebel
+ import torch
+ from rebel.compile_context import CompileContext
  from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

- from ....modeling_base import RBLNModel
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+ from ....modeling import RBLNModel
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime


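Note on the hunk above: modeling_base.py is largely superseded by the new modeling.py, and RBLNModel now comes from the latter. A minimal sketch of the corresponding import change for downstream code, assuming the public paths simply mirror the module layout in the file list (not verified against the released wheels):

# optimum-rbln 0.1.13
# from optimum.rbln.modeling_base import RBLNModel

# optimum-rbln 0.2.0 (assumed public path, mirroring the hunk above)
from optimum.rbln.modeling import RBLNModel
from optimum.rbln.modeling_config import RBLNCompileConfig, RBLNConfig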
@@ -66,7 +67,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  """
  This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
- This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+ This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

  A class to convert and run pre-trained transformers based Seq2SeqLM models on RBLN devices.
  It implements the methods to convert a pre-trained transformers Seq2SeqLM model into a RBLN transformer model by:
@@ -88,49 +89,42 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
  wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

- wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
- wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+ enc_compile_config = rbln_config.compile_cfgs[0]
+ dec_compile_config = rbln_config.compile_cfgs[1]

- wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
- wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+ context = CompileContext(use_weight_sharing=False)

- enc_rbln_compile_config = rbln_config.compile_cfgs[0]
- dec_rbln_compile_config = rbln_config.compile_cfgs[1]
+ enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)

- enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
- dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+ # Mark encoder's static tensors (cross kv states)
+ static_tensors = {}
+ for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+ if "key_value_states" in name:
+ static_tensors[name] = tensor
+ context.mark_static_address(tensor)

- enc_example_inputs[3].fill_(0)
- dec_example_inputs[4].fill_(-1)
+ dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

- enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
- dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
+ # Mark decoder's static tensors (self kv states)
+ for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+ if "key_value_states" in name:
+ context.mark_static_address(tensor)

- enc_ir = rebel.torchscript_to_ir(
- enc_scripted_model,
- input_names=[v[0] for v in enc_rbln_compile_config.input_info],
- name=enc_rbln_compile_config.mod_name,
+ compiled_encoder = super().compile(
+ wrapped_model.encoder,
+ enc_compile_config,
+ example_inputs=enc_example_inputs,
+ compile_context=context,
  )
- dec_ir = rebel.torchscript_to_ir(
- dec_scripted_model,
- input_names=[v[0] for v in dec_rbln_compile_config.input_info],
- name=dec_rbln_compile_config.mod_name,
- )
-
- connections = [
- (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
- (dec_ir.outputs[1], dec_ir.inputs[5]),
- ]

- compiled_model = rebel.compile(
- enc_ir,
- dec_ir,
- connections=connections,
- fusion=enc_rbln_compile_config.fusion,
- npu=enc_rbln_compile_config.npu,
- tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
+ compiled_decoder = super().compile(
+ wrapped_model.decoder,
+ dec_compile_config,
+ example_inputs=dec_example_inputs,
+ compile_context=context,
  )
- return compiled_model
+
+ return {"encoder": compiled_encoder, "decoder": compiled_decoder}

  @classmethod
  def _get_rbln_config(
@@ -204,7 +198,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  ],
  "float32",
  ),
- ("batch_idx", [], "int32"),
+ ("batch_position", [], "int16"),
  ]

  dec_input_info = [
@@ -216,17 +210,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  [rbln_batch_size, 1],
  "int32",
  ),
- ("batch_position", [], "int32"),
  ]
  dec_input_info.extend(
  [
  (
- "self_key_value_states",
+ "cross_key_value_states",
  [
  n_layer * 2,
  rbln_batch_size,
  n_head,
- rbln_dec_max_seq_len,
+ rbln_enc_max_seq_len,
  d_kv,
  ],
  "float32",
@@ -236,24 +229,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  dec_input_info.extend(
  [
  (
- "cross_key_value_states",
+ f"self_key_value_states_{i}",
  [
- n_layer * 2,
  rbln_batch_size,
  n_head,
- rbln_enc_max_seq_len,
+ rbln_dec_max_seq_len,
  d_kv,
  ],
  "float32",
  )
+ for i in range(n_layer * 2)
  ]
  )
- enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
- dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
+ enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+ dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

  rbln_config = RBLNConfig(
  rbln_cls=cls.__name__,
- compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+ compile_cfgs=[enc_compile_config, dec_compile_config],
  rbln_kwargs=rbln_kwargs,
  )

@@ -270,12 +263,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

  @classmethod
  def _create_runtimes(
- cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+ cls,
+ compiled_models: List[rebel.RBLNCompiledModel],
+ rbln_device_map: Dict[str, int],
+ activate_profiler: Optional[bool] = None,
  ) -> List[rebel.Runtime]:
- device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+ if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+ cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
  return [
- compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
- compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+ compiled_models[0].create_runtime(
+ tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+ ),
+ compiled_models[1].create_runtime(
+ tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+ ),
  ]

  def can_generate(self):
@@ -340,57 +342,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  attention_mask=dec_attention_mask,
  encoder_attention_mask=attention_mask,
  cache_position=cache_position,
- batch_position=torch.tensor(0, dtype=torch.int32),
  )
- lm_logits = decoder_output.logits[0]
+ lm_logits = decoder_output.logits

  return Seq2SeqLMOutput(logits=lm_logits)

- def vllm_forward(
- self,
- input_ids: torch.LongTensor = None,
- cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
- batch_idx: Optional[torch.LongTensor] = None,
- enc_lengths: List[int] = None, # vllm return current attention_mask length
- **kwargs,
- ) -> Tuple[torch.FloatTensor]:
- # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
- # encoder
- if batch_idx is not None:
- enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
- enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
- padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
- input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
- _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
- logits = torch.zeros(1, 1, self.config.vocab_size + 100)
- logits[0][0][-1] = 1
- # decoder
- else:
- input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
- cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
- enc_attention_mask = torch.zeros(
- self.rbln_config.model_cfg["batch_size"],
- self.rbln_config.model_cfg["enc_max_seq_len"],
- dtype=torch.float32,
- )
- dec_attention_mask = torch.zeros(
- self.rbln_config.model_cfg["batch_size"],
- self.rbln_config.model_cfg["dec_max_seq_len"],
- dtype=torch.float32,
- )
- for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
- enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
- logits = self._forward_decoder(
- attention_mask=enc_attention_mask,
- decoder_input_ids=input_ids,
- decoder_attention_mask=dec_attention_mask,
- cache_position=cache_position,
- ).logits
-
- return Seq2SeqLMOutput(logits=logits)
-
  def _prepare_encoder_decoder_kwargs_for_generation(
  self,
  inputs_tensor: torch.Tensor,
@@ -426,15 +382,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  )

  # 3. make sure that encoder returns `ModelOutput`
- model_input_name = model_input_name if model_input_name is not None else self.main_input_name
  encoder_kwargs["return_dict"] = True
  encoder_kwargs["output_hidden_states"] = False
  encoder_kwargs["output_attentions"] = False

  for b in range(batch_size):
- batch_idx = torch.tensor(b, dtype=torch.int32)
+ batch_position = torch.tensor(b, dtype=torch.int16)
  encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
  encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
- model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
+ model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_position=batch_position)

  return model_kwargs