optimum-rbln: optimum_rbln-0.1.12-py3-none-any.whl → optimum_rbln-0.1.13-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +5 -1
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
  4. optimum/rbln/diffusers/models/controlnet.py +36 -56
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
  16. optimum/rbln/modeling_base.py +12 -5
  17. optimum/rbln/modeling_diffusers.py +400 -0
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/cache_utils.py +5 -9
  20. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  21. optimum/rbln/transformers/models/__init__.py +80 -31
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
  23. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  25. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
  26. optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
  27. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  29. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  30. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  31. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  32. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
  33. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
  35. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  36. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  37. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  38. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  39. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  40. optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  42. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  43. optimum/rbln/utils/context.py +58 -0
  44. optimum/rbln/utils/decorator_utils.py +55 -0
  45. optimum/rbln/utils/import_utils.py +7 -0
  46. optimum/rbln/utils/runtime_utils.py +4 -4
  47. optimum/rbln/utils/timer_utils.py +2 -2
  48. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
  49. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
  50. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -21,28 +21,18 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- import inspect
25
- import logging
26
- from typing import TYPE_CHECKING, Any, Callable
27
-
28
- from transformers import PhiForCausalLM
29
-
30
- from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
24
+ from ....utils import logging
25
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
31
26
  from .phi_architecture import PhiWrapper
32
27
 
33
28
 
34
- if TYPE_CHECKING:
35
- from transformers import PreTrainedModel
36
-
37
- from ....modeling_config import RBLNConfig
38
-
39
- logger = logging.getLogger(__name__)
29
+ logger = logging.get_logger(__name__)
40
30
 
41
31
 
42
32
  class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
43
33
  """
44
34
  The Phi Model transformer with a language modeling head (linear layer) on top.
45
- This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
35
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
46
36
 
47
37
  A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
48
38
  It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
@@ -50,20 +40,4 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
50
40
  - compiling the resulting graph using the RBLN compiler.
51
41
  """
52
42
 
53
- @classmethod
54
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
55
- rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
56
- return PhiWrapper(model, rbln_max_seq_len).eval()
57
-
58
- def __getattr__(self, __name: str) -> Any:
59
- def redirect(func):
60
- return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
61
-
62
- val = getattr(PhiForCausalLM, __name)
63
-
64
- if isinstance(val, Callable) and "self" in set(
65
- inspect.signature(val).parameters
66
- ):
67
- return redirect(val)
68
-
69
- return val
43
+ _decoder_wrapper_cls = PhiWrapper
@@ -33,46 +33,12 @@ from transformers.modeling_outputs import (
33
33
  from ...cache_utils import RebelDynamicCache
34
34
  from ..decoderonly import (
35
35
  DecoderOnlyWrapper,
36
- DynamicNTKScalingRotaryEmbedding,
37
- LinearScalingRotaryEmbedding,
38
- RotaryEmbedding,
39
36
  apply_rotary_pos_emb,
40
37
  slice_and_unsqueeze_cos_sin,
41
38
  )
42
39
 
43
40
 
44
41
  class PhiWrapper(DecoderOnlyWrapper):
45
- def _init_rope(self):
46
- if self.rope_scaling is None:
47
- rotary_emb = RotaryEmbedding(
48
- int(self.config.partial_rotary_factor * self.head_dim),
49
- max_position_embeddings=self.max_position_embeddings,
50
- base=self.config.rope_theta,
51
- )
52
- else:
53
- scaling_type = self.rope_scaling["type"]
54
- scaling_factor = self.rope_scaling["factor"]
55
- if scaling_type == "linear":
56
- rotary_emb = LinearScalingRotaryEmbedding(
57
- int(self.config.partial_rotary_factor * self.head_dim),
58
- max_position_embeddings=self.max_position_embeddings,
59
- scaling_factor=scaling_factor,
60
- base=self.config.rope_theta,
61
- max_seq_len=self.max_seq_len,
62
- )
63
- elif scaling_type == "dynamic":
64
- rotary_emb = DynamicNTKScalingRotaryEmbedding(
65
- int(self.config.partial_rotary_factor * self.head_dim),
66
- max_position_embeddings=self.max_position_embeddings,
67
- scaling_factor=scaling_factor,
68
- base=self.config.rope_theta,
69
- max_seq_len=self.max_seq_len,
70
- )
71
- else:
72
- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
73
-
74
- return rotary_emb
75
-
76
42
  def get_forward_dict(self):
77
43
  forward_dict = {}
78
44
  forward_dict.update(
@@ -86,6 +52,41 @@ class PhiWrapper(DecoderOnlyWrapper):
86
52
 
87
53
 
88
54
  class PhiAttention:
55
+ def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
56
+ # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
57
+ key_state = key_state.unsqueeze(2)
58
+ value_state = value_state.unsqueeze(2)
59
+ attn_mask = attn_mask.unsqueeze(2)
60
+
61
+ query_state = query_state.view(
62
+ 1,
63
+ self.num_key_value_heads,
64
+ self.num_heads // self.num_key_value_heads,
65
+ -1,
66
+ self.head_dim,
67
+ )
68
+
69
+ key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
70
+
71
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
72
+ attn_weights = torch.matmul(
73
+ query_state.to(torch.float32),
74
+ key_state.to(torch.float32).transpose(3, 4),
75
+ ) / math.sqrt(self.head_dim)
76
+ attn_weights = attn_weights + attn_mask
77
+
78
+ # upcast attention to fp32
79
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
80
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
81
+ attn_output = torch.matmul(attn_weights, value_state)
82
+
83
+ # reshape for removing repeat_kv
84
+ attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
85
+ attn_output = attn_output.transpose(1, 2).contiguous()
86
+ attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
87
+
88
+ return attn_output, key_state, value_state
89
+
89
90
  def forward(
90
91
  self,
91
92
  hidden_states: torch.Tensor,
@@ -95,7 +96,6 @@ class PhiAttention:
95
96
  output_attentions: bool = False,
96
97
  cos: Optional[torch.Tensor] = None,
97
98
  sin: Optional[torch.Tensor] = None,
98
- rotary_pos_emb=None,
99
99
  **kwargs,
100
100
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
101
101
  bsz, q_len, _ = hidden_states.size()
@@ -108,24 +108,18 @@ class PhiAttention:
108
108
  query_states = self.q_layernorm(query_states)
109
109
  key_states = self.k_layernorm(key_states)
110
110
 
111
- query_states = query_states.view(
112
- bsz, q_len, self.num_heads, self.head_dim
113
- ).transpose(1, 2)
114
- key_states = key_states.view(
115
- bsz, q_len, self.num_key_value_heads, self.head_dim
116
- ).transpose(1, 2)
117
- value_states = value_states.view(
118
- bsz, q_len, self.num_key_value_heads, self.head_dim
119
- ).transpose(1, 2)
111
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
112
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
113
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
120
114
 
121
115
  # Partial rotary embedding
122
116
  query_rot, query_pass = (
123
- query_states[..., : rotary_pos_emb.dim],
124
- query_states[..., rotary_pos_emb.dim :],
117
+ query_states[..., : self.rotary_ndims],
118
+ query_states[..., self.rotary_ndims :],
125
119
  )
126
120
  key_rot, key_pass = (
127
- key_states[..., : rotary_pos_emb.dim],
128
- key_states[..., rotary_pos_emb.dim :],
121
+ key_states[..., : self.rotary_ndims],
122
+ key_states[..., self.rotary_ndims :],
129
123
  )
130
124
 
131
125
  # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
@@ -135,113 +129,38 @@ class PhiAttention:
135
129
  query_states = torch.cat((query_rot, query_pass), dim=-1)
136
130
  key_states = torch.cat((key_rot, key_pass), dim=-1)
137
131
 
138
- # Decoder
139
- if (batch_index is None or batch_index == -1) and bsz > 1:
140
- all_key_states = []
141
- all_value_states = []
142
- all_attn_output = []
143
-
132
+ # Decoder (bsz > 1)
133
+ if bsz > 1:
134
+ iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
144
135
  for b in range(bsz):
145
- query_state = query_states[b].unsqueeze(0)
146
- attn_mask = attention_mask[b].unsqueeze(0)
147
- key_state = key_states[b].unsqueeze(0)
148
- value_state = value_states[b].unsqueeze(0)
149
-
150
- # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
151
- key_state = key_state.unsqueeze(2)
152
- value_state = value_state.unsqueeze(2)
153
- attn_mask = attn_mask.unsqueeze(2)
154
-
155
- query_state = query_state.view(
156
- 1,
157
- self.num_key_value_heads,
158
- self.num_heads // self.num_key_value_heads,
159
- q_len,
160
- self.head_dim,
136
+ attn_output, key_state, value_state = PhiAttention._attn(
137
+ self,
138
+ query_states[b].unsqueeze(0),
139
+ key_states[b].unsqueeze(0),
140
+ value_states[b].unsqueeze(0),
141
+ attention_mask[b].unsqueeze(0),
142
+ past_key_value,
143
+ batch_idx=b,
144
+ is_prefill=False,
161
145
  )
162
-
163
- key_state, value_state = past_key_value.update(
164
- key_state,
165
- value_state,
166
- self.layer_idx,
167
- b,
168
- )
169
-
170
- # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
171
- attn_weights = torch.matmul(
172
- query_state.to(torch.float32),
173
- key_state.to(torch.float32).transpose(3, 4),
174
- ) / math.sqrt(self.head_dim)
175
- attn_weights = attn_weights + attn_mask
176
-
177
- # upcast attention to fp32
178
- attn_weights = nn.functional.softmax(
179
- attn_weights, dim=-1, dtype=torch.float32
180
- ).to(query_states.dtype)
181
- attn_weights = nn.functional.dropout(
182
- attn_weights, p=self.attention_dropout, training=self.training
183
- )
184
- attn_output = torch.matmul(attn_weights, value_state)
185
-
186
- # reshape for removing repeat_kv
187
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
188
- attn_output = attn_output.transpose(1, 2).contiguous()
189
- attn_output = attn_output.reshape(
190
- 1, q_len, self.num_heads * self.head_dim
191
- )
192
-
193
- all_key_states.append(key_state)
194
- all_value_states.append(value_state)
195
- all_attn_output.append(attn_output)
196
-
197
- key_states = torch.cat(all_key_states, dim=0)
198
- value_states = torch.cat(all_value_states, dim=0)
199
- attn_output = torch.cat(all_attn_output, dim=0)
146
+ iterate_results["key_states"].append(key_state)
147
+ iterate_results["value_states"].append(value_state)
148
+ iterate_results["attn_output"].append(attn_output)
149
+
150
+ key_states = torch.cat(iterate_results["key_states"], dim=0)
151
+ value_states = torch.cat(iterate_results["value_states"], dim=0)
152
+ attn_output = torch.cat(iterate_results["attn_output"], dim=0)
153
+ # Prefill & Decoder (bsz == 1)
200
154
  else:
201
- if batch_index is None or batch_index == -1:
202
- batch_index = 0
203
-
204
- # reshape for removing repeat_kv
205
- key_states = key_states.unsqueeze(2)
206
- value_states = value_states.unsqueeze(2)
207
- attention_mask = attention_mask.unsqueeze(2)
208
- query_states = query_states.view(
209
- 1,
210
- self.num_key_value_heads,
211
- self.num_heads // self.num_key_value_heads,
212
- q_len,
213
- self.head_dim,
214
- )
215
-
216
- key_states, value_states = past_key_value.update(
155
+ attn_output, key_states, value_states = PhiAttention._attn(
156
+ self,
157
+ query_states,
217
158
  key_states,
218
159
  value_states,
219
- self.layer_idx,
220
- batch_index,
221
- read_first_step=True,
222
- )
223
-
224
- # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
225
- attn_weights = torch.matmul(
226
- query_states.to(torch.float32),
227
- key_states.to(torch.float32).transpose(3, 4),
228
- ) / math.sqrt(self.head_dim)
229
- attn_weights = attn_weights + attention_mask
230
-
231
- # upcast attention to fp32
232
- attn_weights = torch.nn.functional.softmax(
233
- attn_weights, dim=-1, dtype=torch.float32
234
- ).to(value_states.dtype)
235
- attn_weights = torch.nn.functional.dropout(
236
- attn_weights, p=self.attention_dropout, training=self.training
237
- )
238
- attn_output = torch.matmul(attn_weights, value_states)
239
-
240
- # reshape for removing repeat_kv
241
- attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
242
- attn_output = attn_output.transpose(1, 2).contiguous()
243
- attn_output = attn_output.reshape(
244
- bsz, q_len, self.num_heads * self.head_dim
160
+ attention_mask,
161
+ past_key_value,
162
+ batch_idx=batch_index,
163
+ is_prefill=True,
245
164
  )
246
165
 
247
166
  attn_output = self.dense(attn_output)
@@ -265,12 +184,9 @@ class PhiDecoderLayer:
265
184
  batch_ids: Optional[torch.LongTensor] = None,
266
185
  cos: Optional[torch.Tensor] = None,
267
186
  sin: Optional[torch.Tensor] = None,
268
- rotary_pos_emb=None,
269
187
  forward_dict: Optional[Dict[str, classmethod]] = None,
270
188
  **kwargs,
271
- ) -> Tuple[
272
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
273
- ]:
189
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
274
190
  """
275
191
  Args:
276
192
  hidden_states (`torch.FloatTensor`):
@@ -294,9 +210,7 @@ class PhiDecoderLayer:
294
210
  hidden_states = self.input_layernorm(hidden_states)
295
211
 
296
212
  # Self Attention
297
- attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
298
- "decoder_layer"
299
- ](
213
+ attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
300
214
  self.self_attn,
301
215
  hidden_states=hidden_states,
302
216
  attention_mask=attention_mask,
@@ -307,7 +221,6 @@ class PhiDecoderLayer:
307
221
  use_cache=use_cache,
308
222
  cos=cos,
309
223
  sin=sin,
310
- rotary_pos_emb=rotary_pos_emb,
311
224
  **kwargs,
312
225
  )
313
226
  past_key_value.assign(key_states, value_states, layer_idx)
@@ -339,6 +252,8 @@ class PhiModel:
339
252
  use_cache: Optional[bool] = True,
340
253
  output_attentions: Optional[bool] = False,
341
254
  output_hidden_states: Optional[bool] = False,
255
+ cache_pos_for_partitions: Optional[torch.Tensor] = None,
256
+ kvcache_partition_size: Optional[torch.Tensor] = None,
342
257
  forward_dict: Optional[Dict[str, classmethod]] = None,
343
258
  rotary_pos_emb=None,
344
259
  ) -> BaseModelOutputWithPast:
@@ -378,7 +293,8 @@ class PhiModel:
378
293
  batch_ids=batch_ids,
379
294
  cos=cos,
380
295
  sin=sin,
381
- rotary_pos_emb=rotary_pos_emb,
296
+ cache_pos_for_partitions=cache_pos_for_partitions,
297
+ kvcache_partition_size=kvcache_partition_size,
382
298
  forward_dict=forward_dict,
383
299
  )
384
300
 
@@ -21,28 +21,18 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- import inspect
25
- import logging
26
- from typing import TYPE_CHECKING, Any, Callable
27
-
28
- from transformers import Qwen2ForCausalLM
29
-
30
- from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
24
+ from ....utils import logging
25
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
31
26
  from .qwen2_architecture import QWEN2Wrapper
32
27
 
33
28
 
34
- if TYPE_CHECKING:
35
- from transformers import PreTrainedModel
36
-
37
- from ....modeling_config import RBLNConfig
38
-
39
- logger = logging.getLogger(__name__)
29
+ logger = logging.get_logger(__name__)
40
30
 
41
31
 
42
32
  class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
43
33
  """
44
34
  The Llama Model transformer with a language modeling head (linear layer) on top.
45
- This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
35
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
46
36
 
47
37
  A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
48
38
  It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
50
40
  - compiling the resulting graph using the RBLN compiler.
51
41
  """
52
42
 
53
- @classmethod
54
- def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
55
- rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
56
- return QWEN2Wrapper(model, rbln_max_seq_len).eval()
57
-
58
- def __getattr__(self, __name: str) -> Any:
59
- def redirect(func):
60
- return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
61
-
62
- val = getattr(Qwen2ForCausalLM, __name)
63
-
64
- if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
65
- return redirect(val)
66
-
67
- return val
43
+ _decoder_wrapper_cls = QWEN2Wrapper
@@ -21,5 +21,5 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .modeling_t5 import RBLNT5ForConditionalGeneration
24
+ from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
25
25
  from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
@@ -22,11 +22,16 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import inspect
25
- from typing import TYPE_CHECKING, Any, Callable
25
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
26
26
 
27
- from transformers import T5ForConditionalGeneration
27
+ from transformers import (
28
+ AutoModelForTextEncoding,
29
+ PretrainedConfig,
30
+ T5ForConditionalGeneration,
31
+ )
28
32
 
29
- from ....modeling_config import RBLNConfig
33
+ from ....modeling_base import RBLNModel
34
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
30
35
  from ....utils.logging import get_logger
31
36
  from ...models.seq2seq import RBLNModelForSeq2SeqLM
32
37
  from .t5_architecture import T5Wrapper
@@ -35,7 +40,55 @@ from .t5_architecture import T5Wrapper
35
40
  logger = get_logger()
36
41
 
37
42
  if TYPE_CHECKING:
38
- from transformers import PreTrainedModel
43
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
44
+
45
+
46
+ class RBLNT5EncoderModel(RBLNModel):
47
+ auto_model_class = AutoModelForTextEncoding
48
+
49
+ @classmethod
50
+ def _get_rbln_config(
51
+ cls,
52
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
53
+ model_config: Optional["PretrainedConfig"] = None,
54
+ rbln_kwargs: Dict[str, Any] = {},
55
+ ) -> RBLNConfig:
56
+ rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
57
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
58
+
59
+ max_position_embeddings = getattr(model_config, "n_positions", None)
60
+
61
+ if rbln_max_seq_len is None:
62
+ rbln_max_seq_len = max_position_embeddings
63
+ if rbln_max_seq_len is None:
64
+ for tokenizer in preprocessors:
65
+ if hasattr(tokenizer, "model_max_length"):
66
+ rbln_max_seq_len = tokenizer.model_max_length
67
+ break
68
+ if rbln_max_seq_len is None:
69
+ raise ValueError("`rbln_max_seq_len` should be specified!")
70
+
71
+ if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
72
+ raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
73
+
74
+ if rbln_batch_size is None:
75
+ rbln_batch_size = 1
76
+
77
+ input_info = [
78
+ ("input_ids", [rbln_batch_size, rbln_max_seq_len], "int64"),
79
+ ("attention_mask", [rbln_batch_size, rbln_max_seq_len], "int64"),
80
+ ]
81
+
82
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
83
+
84
+ rbln_config = RBLNConfig(
85
+ rbln_cls=cls.__name__,
86
+ compile_cfgs=[rbln_compile_config],
87
+ rbln_kwargs=rbln_kwargs,
88
+ )
89
+
90
+ rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
91
+ return rbln_config
39
92
 
40
93
 
41
94
  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
@@ -102,7 +102,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
102
102
  class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
103
103
  """
104
104
  The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
105
- This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
105
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
106
106
 
107
107
  A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
108
108
  It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -31,8 +31,13 @@ from torch.nn import functional as F
31
31
 
32
32
  # Constants
33
33
  QUANTIZED_WEIGHTS = {
34
- "q_proj", "k_proj", "v_proj", "o_proj",
35
- "gate_proj", "up_proj", "down_proj",
34
+ "q_proj",
35
+ "k_proj",
36
+ "v_proj",
37
+ "o_proj",
38
+ "gate_proj",
39
+ "up_proj",
40
+ "down_proj",
36
41
  }
37
42
 
38
43
 
@@ -81,6 +86,7 @@ def create_qlinear(layer: Linear) -> Linear:
81
86
  """
82
87
  Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
83
88
  """
89
+
84
90
  def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
85
91
  if inputs.dtype != self.scales.dtype:
86
92
  raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
@@ -0,0 +1,58 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from contextlib import contextmanager
25
+ from pathlib import Path
26
+ from typing import Union
27
+
28
+ from optimum.exporters import TasksManager
29
+ from transformers import AutoConfig, AutoModel
30
+
31
+
32
+ @contextmanager
33
+ def override_auto_classes(config_func=None, model_func=None, skip_taskmanager=True):
34
+ """Temporarily override Auto classes with original model classes"""
35
+ original_config = AutoConfig.from_pretrained
36
+ original_model = AutoModel.from_pretrained
37
+ original_get_model_from_task = TasksManager.get_model_from_task
38
+
39
+ def get_model_from_task(
40
+ task: str,
41
+ model_name_or_path: Union[str, Path],
42
+ **kwargs,
43
+ ):
44
+ return model_func(model_name_or_path, **kwargs)
45
+
46
+ def none_func(*args, **kwargs):
47
+ return None
48
+
49
+ try:
50
+ AutoConfig.from_pretrained = config_func or none_func
51
+ AutoModel.from_pretrained = model_func or none_func
52
+ if skip_taskmanager:
53
+ TasksManager.get_model_from_task = none_func if model_func is None else get_model_from_task
54
+ yield
55
+ finally:
56
+ AutoConfig.from_pretrained = original_config
57
+ AutoModel.from_pretrained = original_model
58
+ TasksManager.get_model_from_task = original_get_model_from_task
@@ -0,0 +1,55 @@
1
+ from functools import wraps
2
+
3
+ from .logging import get_logger
4
+
5
+
6
+ logger = get_logger(__name__)
7
+
8
+
9
+ def remove_compile_time_kwargs(func):
10
+ """
11
+ Decorator to handle compile-time parameters during inference.
12
+
13
+ For RBLN-optimized pipelines, several parameters must be determined during compilation
14
+ and cannot be modified during inference. This decorator:
15
+ 1. Removes and warns about LoRA scale in cross_attention_kwargs
16
+ 2. Removes and warns about image dimension parameters (height, width)
17
+
18
+ Args:
19
+ func: The pipeline's __call__ method to be wrapped
20
+ """
21
+
22
+ @wraps(func)
23
+ def wrapper(self, *args, **kwargs):
24
+ height_exists = "height" in kwargs and kwargs["height"] is not None
25
+ width_exists = "width" in kwargs and kwargs["width"] is not None
26
+ if height_exists or width_exists:
27
+ logger.warning(
28
+ "Image dimension parameters (`height`, `width`) will be ignored during inference. "
29
+ "Image dimensions must be specified during model compilation using from_pretrained()."
30
+ )
31
+ kwargs.pop("width", None)
32
+ kwargs.pop("height", None)
33
+
34
+ if "cross_attention_kwargs" in kwargs:
35
+ cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
36
+ if not cross_attention_kwargs:
37
+ return func(self, *args, **kwargs)
38
+
39
+ has_scale = "scale" in cross_attention_kwargs
40
+ if has_scale:
41
+ logger.warning(
42
+ "LoRA scale in cross_attention_kwargs will be ignored during inference. "
43
+ "To adjust LoRA scale, specify it during model compilation using from_pretrained()."
44
+ )
45
+
46
+ # If scale is the only key, set to None
47
+ # Otherwise, remove scale and preserve other settings
48
+ if len(cross_attention_kwargs) == 1:
49
+ kwargs["cross_attention_kwargs"] = None
50
+ else:
51
+ kwargs["cross_attention_kwargs"].pop("scale")
52
+
53
+ return func(self, *args, **kwargs)
54
+
55
+ return wrapper
@@ -37,6 +37,13 @@ class VersionCompat:
37
37
 
38
38
 
39
39
  RBLN_VERSION_COMPATS = {
40
+ "0.1.13": [
41
+ VersionCompat(
42
+ package_name="rebel-compiler",
43
+ min_version="0.6.0",
44
+ max_version="0.6.1",
45
+ ),
46
+ ],
40
47
  "0.1.12": [
41
48
  VersionCompat(
42
49
  package_name="rebel-compiler",