optimum_rbln-0.1.13-py3-none-any.whl → optimum_rbln-0.1.15-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
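The headline additions in this release are the Stable Diffusion 3 pipelines (entries 22–25) and the reorganization of the diffusers models into autoencoders/, transformers/, and unets/ subpackages. A hedged usage sketch follows; the class name RBLNStableDiffusion3Pipeline mirrors the naming of the existing RBLNStableDiffusionXLPipeline and, like the checkpoint id and options shown, is an assumption to verify against the 0.1.15 release rather than something stated in this diff. Additional rbln_* compile-time options (image size, batch size) would typically be passed as well.

# Hypothetical usage of the new Stable Diffusion 3 support; the class name,
# checkpoint id, and options are illustrative assumptions, not taken from this diff.
from optimum.rbln import RBLNStableDiffusion3Pipeline

pipe = RBLNStableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint id
    export=True,  # compile the diffusers checkpoint for RBLN NPUs at load time
)
pipe.save_pretrained("sd3-rbln")  # reuse the compiled artifacts without re-exporting
image = pipe("A watercolor fox in a forest").images[0]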
optimum/rbln/transformers/models/midm/midm_architecture.py

@@ -21,18 +21,24 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Tuple

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast

-from ....transformers.models.decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
-    rotate_half,
-    slice_and_unsqueeze_cos_sin,
+from ....transformers.models.decoderonly.decoderonly_architecture import rotate_half
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb_partial,
 )
-from ...cache_utils import RebelDynamicCache_4D
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as MidmLMHeadModel


 def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
@@ -50,253 +56,93 @@ def apply_rotary_pos_emb(q, k, cos, sin):
     return q_embed, k_embed


-class MidmLMHeadModelWrapper(torch.nn.Module):
-    """A wrapper class for the Midm model with a language modeling head."""
-
-    def __init__(self, model, max_seq_len):
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-
-        self.config.partial_rotary_factor = model.config.rotary_percentage
-        self.config.head_dim = self.config.n_embd // self.config.n_head
+class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+    def get_rotary_emb(self, max_seq_len):
         self.config.rope_theta = 10000
-        self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cache_position: torch.LongTensor,
-        batch_position: int,
-        query_idx: int,
-        *past_key_values,
-    ):
-        """Defines the forward pass for the wrapper model."""
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-
-        past_key_values = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.num_hidden_layers,
-            *past_key_values,
-        )
-
-        outputs = _MidmModel.forward(
-            self.model,
-            input_ids=input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            rotary_pos_emb=self.rotary_emb,
-            batch_ids=rbln_batch_position,
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-
-        logits = self.lm_head(hidden_states)
-        output = (logits,) + outputs[1:]
-
-        return output, batch_position + query_idx
-
-
-def layernorm1p(module, input):
-    """Applies Layer Normalization with a slight modification on the weights."""
-    return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
-
-
-class _MidmAttention:
-    """Custom implementation of the MidmAttention class with specific modifications."""
-
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        """Computes the attention weights and output."""
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+        self.config.head_dim = self.config.n_embd // self.config.n_head
+        self.config.partial_rotary_factor = self.config.rotary_percentage
+        return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = MidmAttention(layer.attn)
+            new_layer = MidmLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class MidmModel(DecoderOnlyModel):
+    mask_fmin = -10000.0
+
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
             )

-        if self.scale_attn_by_inverse_layer_idx or self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
+        return layernorm1p

-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        if self.scale_qk_by_inverse_layer_idx:
-            attn_weights = attn_weights * float(self.layer_idx + 1)
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_weights = attn_weights.type(value.dtype)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        attn_output = torch.matmul(attn_weights, value)
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        """Defines the forward pass for the attention mechanism."""
-        bsz, q_len, _ = hidden_states.size()
-
-        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
-        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
-        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
-
-        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_f)
+        else:
+            return self._original_mod.ln_f

-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte

-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe

-                attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

-                all_key_states.append(key)
-                all_value_states.append(value)
-                all_attn_output.append(attn_output)
+class MidmLayer(DecoderOnlyLayer):
+    def get_layernorm1p(self, module: nn.LayerNorm):
+        def layernorm1p(input: torch.Tensor):
+            """Applies Layer Normalization with a slight modification on the weights."""
+            return torch.nn.functional.layer_norm(
+                input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+            )

-            keys = torch.cat(all_key_states, dim=0)
-            values = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
+        return layernorm1p

+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_1)
         else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
+            return self._original_mod.ln_1

-            attn_output, _ = _MidmAttention._attn(self, querys, keys, values, attention_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-        return attn_output, keys, values
-
-
-class _MidmBlock:
-    """Custom implementation of the MidmBlock class with specific modifications."""
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
-        """Defines the forward pass for the block."""
-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_1, hidden_states)
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        if self._original_mod.use_layernorm1p:
+            return self.get_layernorm1p(self._original_mod.ln_2)
         else:
-            hidden_states = self.ln_1(hidden_states)
-
-        hidden_states, k, v = _MidmAttention.forward(
-            self.attn,
-            hidden_states,
-            attention_mask=attention_mask,
-            past_key_value=past_key_value,
-            cos=cos,
-            sin=sin,
-            batch_index=batch_ids,
+            return self._original_mod.ln_2
+
+
+class MidmAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
+
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_weights=self._original_mod.scale_attn_weights,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        hidden_states = hidden_states + residual

-        residual = hidden_states
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_2, hidden_states)
-        else:
-            hidden_states = self.ln_2(hidden_states)
-
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _MidmModel:
-    """Custom implementation of the MidmModel class with specific modifications."""
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache_4D] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        rotary_pos_emb=None,
-        batch_ids: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """Defines the forward pass for the model."""
-        input_shape = input_ids.size()
-
-        attention_mask = (1.0 - attention_mask) * -10000.0
-
-        inputs_embeds = self.wte(input_ids)
-
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-        hidden_states = inputs_embeds
-
-        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
-            hidden_states, updated_cache = _MidmBlock.forward(
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_values,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-            )
-
-        hidden_states = layernorm1p(self.ln_f, hidden_states)
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
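After this rewrite, the Midm-specific code reduces to getters (get_embedding, get_pre_attention_layernorm, projection, ...) that map GPT-2-style attribute names (wte, wpe, ln_1, ln_2, c_attn, c_proj) onto the shared DecoderOnly* base classes; the only surviving numeric quirk is the layernorm1p closure, where Midm checkpoints store LayerNorm weights as an offset from 1.0 so the effective scale is module.weight + 1. A minimal standalone sketch of that equivalence, independent of optimum.rbln (the helper name make_layernorm1p is illustrative):

import torch
import torch.nn as nn


def make_layernorm1p(module: nn.LayerNorm):
    # Mirrors the closure returned by get_layernorm1p() in the diff above.
    def layernorm1p(input: torch.Tensor):
        return nn.functional.layer_norm(
            input, module.normalized_shape, module.weight + 1, module.bias, module.eps
        )

    return layernorm1p


ln = nn.LayerNorm(8)
nn.init.normal_(ln.weight, std=0.02)  # weights stored near zero, as a layernorm1p-style checkpoint would

# Reference: an ordinary LayerNorm with the +1 offset folded into its weight.
ln_ref = nn.LayerNorm(8)
with torch.no_grad():
    ln_ref.weight.copy_(ln.weight + 1)
    ln_ref.bias.copy_(ln.bias)

x = torch.randn(2, 4, 8)
assert torch.allclose(make_layernorm1p(ln)(x), ln_ref(x), atol=1e-6)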
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -21,12 +21,12 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+
+from transformers import AutoModelForCausalLM
+
 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .hf_hub_cached.modeling_midm import MidmLMHeadModel
-from .midm_architecture import (
-    MidmLMHeadModelWrapper,
-)
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .midm_architecture import MidmLMHeadModelWrapper


 logger = logging.get_logger(__name__)
@@ -47,7 +47,7 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = MidmLMHeadModelWrapper
-    _original_cls = MidmLMHeadModel
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
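With the vendored hf_hub_cached modeling code removed (entries 72–75 in the file list), RBLNMidmLMHeadModel now resolves the original checkpoint through transformers' AutoModelForCausalLM instead of a bundled MidmLMHeadModel class. A hedged usage sketch follows; the checkpoint id and rbln_* compile options are illustrative assumptions, and trust_remote_code is shown because Midm publishes its modeling code on the Hub:

from optimum.rbln import RBLNMidmLMHeadModel

# Illustrative export-and-reuse flow; arguments other than export=True are assumptions.
model = RBLNMidmLMHeadModel.from_pretrained(
    "KT-AI/midm-bitext-S-7B-inst-v1",  # illustrative checkpoint id
    export=True,                        # compile the Hugging Face checkpoint for RBLN NPUs
    trust_remote_code=True,             # Midm's modeling code lives on the Hub
    rbln_max_seq_len=8192,              # illustrative compile-time option
    rbln_batch_size=1,                  # illustrative compile-time option
)
model.save_pretrained("midm-rbln")      # later loads reuse the compiled artifacts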