optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
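A change that recurs throughout the hunks below is the relocation of the RBLNModel base class: optimum/rbln/modeling.py is new (+572 lines), optimum/rbln/modeling_base.py shrinks accordingly, and each model file switches its import from ....modeling_base to ....modeling. A minimal sketch of that migration for code that imported the class from the internal module path is shown here (assumption: only the module path changed, and the sketch does not cover any public re-exports from optimum.rbln):

# Sketch of the import-path migration applied across the hunks below.
# Assumption: only the internal module path changed; the RBLNModel class name is unchanged.

# optimum-rbln 0.1.13
# from optimum.rbln.modeling_base import RBLNModel

# optimum-rbln 0.1.15
from optimum.rbln.modeling import RBLNModel

print(RBLNModel.__name__)  # -> "RBLNModel"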
optimum/rbln/transformers/models/phi/phi_architecture.py

@@ -21,302 +21,102 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import math
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
-import torch.nn as nn
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from transformers import PhiForCausalLM
 
-from ...cache_utils import RebelDynamicCache
-from ..decoderonly import (
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    apply_rotary_pos_emb,
-    slice_and_unsqueeze_cos_sin,
+    apply_rotary_pos_emb_partial,
 )
 
 
-class PhiWrapper(DecoderOnlyWrapper):
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": PhiModel.forward,
-                "model": PhiDecoderLayer.forward,
-                "decoder_layer": PhiAttention.forward,
-            }
-        )
-        return forward_dict
-
-
-class PhiAttention:
-    def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
-        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-        key_state = key_state.unsqueeze(2)
-        value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
-
-        query_state = query_state.view(
-            1,
-            self.num_key_value_heads,
-            self.num_heads // self.num_key_value_heads,
-            -1,
-            self.head_dim,
-        )
-
-        key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
-
-        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-        attn_weights = torch.matmul(
-            query_state.to(torch.float32),
-            key_state.to(torch.float32).transpose(3, 4),
-        ) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attn_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_state)
-
-        # reshape for removing repeat_kv
-        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+if TYPE_CHECKING:
+    from transformers import PhiForCausalLM
 
-        return attn_output, key_state, value_state
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[int] = None,
-        output_attentions: bool = False,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
 
+class PhiWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = PhiAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = PhiLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = PhiModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class PhiAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.dense
+        self.qk_layernorm = self._original_mod.qk_layernorm
+        self.rotary_ndims = self._original_mod.rotary_ndims
+        self.num_key_value_heads = self.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            query_states = self.q_layernorm(query_states)
-            key_states = self.k_layernorm(key_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        # Partial rotary embedding
-        query_rot, query_pass = (
-            query_states[..., : self.rotary_ndims],
-            query_states[..., self.rotary_ndims :],
-        )
-        key_rot, key_pass = (
-            key_states[..., : self.rotary_ndims],
-            key_states[..., self.rotary_ndims :],
-        )
+            query_states = self._original_mod.q_layernorm(query_states)
+            key_states = self._original_mod.k_layernorm(key_states)
 
-        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        return query_states, key_states, value_states
 
-        # [batch_size, seq_length, num_heads, head_dim]
-        query_states = torch.cat((query_rot, query_pass), dim=-1)
-        key_states = torch.cat((key_rot, key_pass), dim=-1)
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)
 
-        # Decoder (bsz > 1)
-        if bsz > 1:
-            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
-            for b in range(bsz):
-                attn_output, key_state, value_state = PhiAttention._attn(
-                    self,
-                    query_states[b].unsqueeze(0),
-                    key_states[b].unsqueeze(0),
-                    value_states[b].unsqueeze(0),
-                    attention_mask[b].unsqueeze(0),
-                    past_key_value,
-                    batch_idx=b,
-                    is_prefill=False,
-                )
-                iterate_results["key_states"].append(key_state)
-                iterate_results["value_states"].append(value_state)
-                iterate_results["attn_output"].append(attn_output)
 
-            key_states = torch.cat(iterate_results["key_states"], dim=0)
-            value_states = torch.cat(iterate_results["value_states"], dim=0)
-            attn_output = torch.cat(iterate_results["attn_output"], dim=0)
-        # Prefill & Decoder (bsz == 1)
-        else:
-            attn_output, key_states, value_states = PhiAttention._attn(
-                self,
-                query_states,
-                key_states,
-                value_states,
-                attention_mask,
-                past_key_value,
-                batch_idx=batch_index,
-                is_prefill=True,
-            )
+class PhiLayer(DecoderOnlyLayer):
+    def get_post_attention_layernorm(self):
+        raise NotImplementedError
 
-        attn_output = self.dense(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, key_states, value_states
-
-
-class PhiDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-        layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`):
-                input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
-                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
+    ):
         residual = hidden_states
 
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-        # Self Attention
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
-            self.self_attn,
+        attn_outputs, present_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            current_steps=current_steps,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
            cos=cos,
            sin=sin,
-            **kwargs,
        )
-        past_key_value.assign(key_states, value_states, layer_idx)
 
-        attn_outputs = self.resid_dropout(attn_outputs)
+        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
 
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (past_key_value,)
-
-        return outputs
-
-
-class PhiModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> BaseModelOutputWithPast:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        hidden_states = inputs_embeds
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                cache_pos_for_partitions=cache_pos_for_partitions,
-                kvcache_partition_size=kvcache_partition_size,
-                forward_dict=forward_dict,
-            )
 
-            hidden_states = layer_outputs[0]
+        return hidden_states, present_key_values
 
-            updated_cache = layer_outputs[2 if output_attentions else 1]
 
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.final_layernorm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class PhiModel(DecoderOnlyModel):
+    def get_last_layernorm(self):
+        return self._original_mod.final_layernorm
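The rewritten phi_architecture.py above drops the forward_dict monkey-patching of the Hugging Face classes and instead subclasses the shared decoder-only building blocks, overriding only Phi-specific hooks (projection naming, partial rotary embedding, final layer norm). A hypothetical sketch of the same pattern for another decoder-only port follows; MyAttention and MyWrapper are illustrative names only, and the sketch assumes the DecoderOnly* base classes accept the constructor arguments their Phi subclasses receive above.

# Hypothetical sketch only: "MyAttention" / "MyWrapper" are not part of optimum-rbln.
# It assumes the DecoderOnly* base classes take the constructor arguments shown in the
# Phi hunk above and that their default layer-norm hooks fit the target model.
from typing import Tuple

import torch

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    DecoderOnlyAttention,
    DecoderOnlyForCausalLM,
    DecoderOnlyLayer,
    DecoderOnlyModel,
    DecoderOnlyWrapper,
)


class MyAttention(DecoderOnlyAttention):
    def __post_init__(self):
        # Map the original Hugging Face projections onto the names the shared
        # attention implementation expects (mirrors PhiAttention.__post_init__).
        self.q_proj = self._original_mod.q_proj
        self.k_proj = self._original_mod.k_proj
        self.v_proj = self._original_mod.v_proj
        self.o_proj = self._original_mod.o_proj

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Model-specific q/k/v computation; rotary embedding and KV-cache handling
        # stay in the base class.
        return self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)


class MyWrapper(DecoderOnlyWrapper):
    def convert_to_rbln_causal_lm(self, causal_lm):
        # Same structure as PhiWrapper.convert_to_rbln_causal_lm above, using the
        # base layer/model classes directly instead of model-specific subclasses.
        new_layers = [DecoderOnlyLayer(layer, MyAttention(layer.self_attn)) for layer in causal_lm.model.layers]
        new_model = DecoderOnlyModel(causal_lm.model, new_layers)
        return DecoderOnlyForCausalLM(causal_lm, new_model)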

optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -31,7 +31,7 @@ import torch # noqa: F401
 from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 
@@ -346,51 +346,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None,  # vllm return current attention_mask length
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
-        # encoder
-        if batch_idx is not None:
-            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["enc_max_seq_len"],
-                dtype=torch.float32,
-            )
-            dec_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["dec_max_seq_len"],
-                dtype=torch.float32,
-            )
-            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
-
-        return Seq2SeqLMOutput(logits=logits)
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,

optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -22,17 +22,23 @@
 # from Rebellions Inc.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
+import torch
+import transformers
 from transformers import (
     AutoModelForTextEncoding,
     PretrainedConfig,
+    T5EncoderModel,
     T5ForConditionalGeneration,
 )
+from transformers.modeling_outputs import BaseModelOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 from .t5_architecture import T5Wrapper
 
@@ -43,8 +49,60 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 
+class RBLNRuntimeModel(RBLNPytorchRuntime):
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.FloatTensor,
+        head_mask: torch.FloatTensor,
+        inputs_embeds: torch.FloatTensor,
+        **kwargs,
+    ):
+        return super().forward(
+            input_ids,
+            attention_mask,
+            head_mask,
+            inputs_embeds,
+            **kwargs,
+        )
+
+
+class T5EncoderWrapper(torch.nn.Module):
+    def __init__(self, model: "T5EncoderModel") -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(self, *args, **kwargs):
+        kwargs.pop("return_dict", None)
+        return self.model(*args, **kwargs, return_dict=False)
+
+
 class RBLNT5EncoderModel(RBLNModel):
     auto_model_class = AutoModelForTextEncoding
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def __post_init__(self, **kwargs):
+        self.model = RBLNRuntimeModel(runtime=self.model[0])
+
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5EncoderWrapper(model)
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size", 1)
+        max_sequence_length = rbln_config.get("max_sequence_length", 256)
+        model_input_names = ["input_ids"]
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "max_seq_len": max_sequence_length,
+                "model_input_names": model_input_names,
+            }
+        )
+
+        return rbln_config
 
     @classmethod
     def _get_rbln_config(
@@ -54,6 +112,7 @@ class RBLNT5EncoderModel(RBLNModel):
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
 
         max_position_embeddings = getattr(model_config, "n_positions", None)
@@ -71,12 +130,27 @@
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(input_names_order)})"
+                )
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
         input_info = [
-            ("input_ids", [rbln_batch_size, rbln_max_seq_len], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_max_seq_len], "int64"),
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
         ]
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
@@ -90,6 +164,30 @@
         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
         return rbln_config
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        encoder_outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if not return_dict:
+            return (encoder_outputs,)
+        else:
+            return BaseModelOutput(last_hidden_state=encoder_outputs)
+
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
     @classmethod
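With the new _get_rbln_config above, the compiled input signature for RBLNT5EncoderModel is derived from rbln_model_input_names, from a tokenizer's model_input_names, or from the class default. A hedged usage sketch follows; the rbln_* keyword names mirror the rbln_kwargs lookups in the hunk, while the top-level re-export and the export=True flag are assumed from the library's usual from_pretrained convention, and t5-small is only an example checkpoint.

# Usage sketch, not taken from official documentation: keyword names follow the
# rbln_kwargs lookups in _get_rbln_config() above; export=True and the top-level
# RBLNT5EncoderModel re-export are assumed from optimum-rbln's usual conventions.
from optimum.rbln import RBLNT5EncoderModel

model = RBLNT5EncoderModel.from_pretrained(
    "t5-small",                 # example checkpoint
    export=True,                # compile the Hugging Face weights for the RBLN NPU
    rbln_max_seq_len=256,
    rbln_batch_size=1,
    rbln_model_input_names=["input_ids", "attention_mask"],
)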

optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -28,7 +28,7 @@ import torch
 from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 

optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -36,7 +36,7 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from .generation_whisper import RBLNWhisperGenerationMixin

optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -22,12 +22,12 @@
 # from Rebellions Inc.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
-from transformers import PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers import PretrainedConfig
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
@@ -38,38 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNXLMRobertaModel(RBLNModel):
-    original_model_class = XLMRobertaModel
-    original_config_class = XLMRobertaConfig
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        model: "PreTrainedModel" = super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_kwargs=rbln_kwargs,
-            library_name="transformers",
-        )
-
-        return model
-
     @classmethod
     def _get_rbln_config(
         cls,