optimum-rbln 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +7 -0
  68. optimum/rbln/utils/logging.py +37 -0
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  79. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -0,0 +1,498 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+ from transformers.utils import logging
+
+ from ....ops import register_rbln_custom_attention, register_rbln_custom_cache_update
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Seq2SeqWrapper:
+     """A wrapper class for Seq2Seq models to support RBLN-specific optimizations.
+
+     This wrapper divides the Seq2Seq model into separate encoder and decoder wrappers,
+     enabling specific optimizations such as custom cache handling and attention mechanisms.
+
+     Args:
+         model (nn.Module): The Seq2Seq model to wrap.
+         enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
+         **kwargs: Additional arguments to pass to the decoder wrapper.
+     """
+
+     def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
+         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+         self.decoder = Seq2SeqDecoderWrapper(model, **kwargs)
+
+
+ class Seq2SeqEncoderWrapper(nn.Module):
+     """A wrapper for the encoder component of a Seq2Seq model, designed for RBLN optimization.
+
+     This wrapper modifies the standard encoder-decoder architecture of Seq2Seq models to optimize
+     memory usage and attention mechanisms, particularly in cross-attention layers. It supports custom
+     cache handling to improve performance during decoding.
+
+     Args:
+         model (nn.Module): The Seq2Seq model containing the encoder.
+         enc_max_seq_len (int): Maximum sequence length for encoder embeddings and cache sizes.
+     """
+
+     def __init__(self, model: nn.Module, enc_max_seq_len: int):
+         super().__init__()
+         register_rbln_custom_cache_update()
+         self.config = model.config
+         self.encoder = model.get_encoder()
+         self.encoder_max_length = enc_max_seq_len
+         self.__post_init__(model)
+
+     def __post_init__(self, model: nn.Module):
+         """
+         Post-initialization to extract and configure encoder-related attributes.
+
+         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+         by subclasses to modify or add custom attributes as necessary.
+         """
+         self.n_layer = getattr(self.config, "decoder_layers", None)
+         self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+         self.num_heads = self.config.decoder_attention_heads
+         self.d_kv = self.config.d_model // self.num_heads
+
+     def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+         """
+         Extract cross-attention key and value projection layers from the decoder.
+         """
+         return (
+             nn.ModuleList(decoder_layers[i].encoder_attn.k_proj for i in range(self.n_layer)),
+             nn.ModuleList(decoder_layers[i].encoder_attn.v_proj for i in range(self.n_layer)),
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cross_key_values: torch.Tensor,
+         batch_position: torch.Tensor,
+     ) -> Tuple[torch.Tensor]:
+         # 1. get encoder last_hidden_states
+         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_states = encoder_outputs[0]
+
+         # 2. pre-compute cross_attention's past_key_value which used in decoder phase.
+         cross_kv = []
+         for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+             past_k = (
+                 k_proj(last_hidden_states).view(1, self.encoder_max_length, self.num_heads, self.d_kv).transpose(1, 2)
+             )
+             past_v = (
+                 v_proj(last_hidden_states).view(1, self.encoder_max_length, self.num_heads, self.d_kv).transpose(1, 2)
+             )
+
+             cross_kv.append(past_k)
+             cross_kv.append(past_v)
+
+         cross_kv = torch.stack(cross_kv, dim=0)
+
+         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
+         batch_axis = torch.tensor(1, dtype=torch.int16)
+         cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
+             cross_key_values, cross_kv, batch_position, batch_axis
+         )
+
+         return cross_key_values
+
+
+ class Seq2SeqDecoderWrapper(nn.Module):
+     """
+     A wrapper for the decoder component of a Seq2Seq model, designed for RBLN optimization.
+
+     This wrapper handles tasks such as:
+     1. Converting decoder components to support RBLN-specific conditional generation.
+     2. Customizing attention mechanisms, including self-attention and cross-attention.
+     3. Managing the decoder's key-value caches for both self and cross-attention.
+
+     Args:
+         model (nn.Module): The Seq2Seq model containing the decoder.
+         **kwargs: Additional arguments for decoder configuration.
+     """
+
+     def __init__(self, model: nn.Module, **kwargs):
+         super().__init__()
+         self.config = model.config
+         self.__post_init__(model, **kwargs)
+
+     def __post_init__(self, model: nn.Module, **kwargs):
+         """
+         Post-initialization to extract and configure encoder-related attributes.
+
+         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+         by subclasses to modify or add custom attributes as necessary.
+         """
+         register_rbln_custom_attention()
+         self.num_layers = self.config.decoder_layers
+         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
+
+     def convert_to_rbln_conditional_generation(self, model: nn.Module):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = Seq2SeqSelfAttention(layer.self_attn)
+             new_layers.append(Seq2SeqDecoderLayer(layer, self_attn))
+
+         decoder_model = Seq2SeqDecoder(model.get_decoder(), new_layers)
+         new_model = Seq2SeqForConditionalGeneration(model, decoder_model)
+
+         return new_model
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         cross_kv_cache: torch.Tensor,
+         *self_kv_cache: torch.Tensor,
+     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         self_past_key_values = ()
+         cross_past_key_values = ()
+         for i in range(0, self.num_layers * 2, 2):
+             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+         # decode
+         lm_logits, self_present_key_values = self.conditional_generation(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+         )
+
+         outputs = (lm_logits,) + self_present_key_values
+
+         return outputs
+
+
+ class Seq2SeqForConditionalGeneration(nn.Module):
+     """
+     A wrapper for Seq2Seq models supporting RBLN-specific optimizations for conditional generation.
+
+     This class adapts a Seq2Seq model for tasks like machine translation, summarization, or text generation
+     by:
+     1. Wrapping and customizing the decoder component to support key RBLN features.
+     2. Managing rescaling and output processing, if enabled.
+     3. Aligning model behavior with RBLN's static and efficient execution requirements.
+
+     Attributes:
+         has_rescaling (bool): Indicates if output rescaling is applied.
+         config (PretrainedConfig): Configuration from the original Seq2Seq model.
+         lm_head (nn.Linear): The language modeling head for output logits.
+         decoder (nn.Module): The wrapped decoder model.
+     """
+
+     has_rescaling = False
+
+     def __init__(self, model, decoder_model):
+         super().__init__()
+         self.config = model.config
+         self.lm_head = model.lm_head
+         self.decoder = decoder_model
+         self.__post_init__()
+
+     def __post_init__(self):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         encoder_attention_mask,
+         self_past_key_values,
+         cross_past_key_values,
+         cache_position,
+     ):
+         hidden_states, self_present_key_values = self.decoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+         )
+
+         if self.has_rescaling and self.config.tie_word_embeddings:
+             hidden_states = hidden_states * self.scaling
+
+         lm_logits = self.lm_head(hidden_states)
+
+         return lm_logits, self_present_key_values
+
+
+ class Seq2SeqDecoder(torch.nn.Module):
+     """A modified Seq2SeqDecoder implementation optimized for RBLN compilation.
+
+     Args:
+         model: Original Huggingface model to adapt
+         layers (List[Seq2SeqDecoderLayer]): Modified transformer layers optimized for RBLN
+     """
+
+     has_pos_emb = True
+
+     def __init__(self, model, layers, **kwargs):
+         super().__init__()
+         self._original_mod = model
+         self.layers = nn.ModuleList(layers)
+         self.embed_tokens = model.embed_tokens
+         self.final_layer_norm = getattr(model, "final_layer_norm", None)
+         self.__post_init__(**kwargs)
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def get_embedding(self):
+         return self.embed_tokens
+
+     def prepare_attn_mask(self, *args, **kwargs):
+         raise NotImplementedError(
+             "The 'prepare_attn_mask' method is not implemented. Please define this method in a subclass."
+         )
+
+     def apply_position_embedding(self, *args, **kwargs):
+         raise NotImplementedError(
+             "The 'apply_position_embedding' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         self_past_key_values: torch.Tensor,
+         cross_past_key_values: torch.Tensor,
+         cache_position: torch.Tensor,
+     ):
+         # embedding
+         hidden_states = self.get_embedding()(input_ids)
+         attention_mask, encoder_attention_mask = self.prepare_attn_mask(
+             attention_mask, encoder_attention_mask, cache_position=cache_position
+         )
+
+         if self.has_pos_emb:
+             hidden_states = self.apply_position_embedding(hidden_states, cache_position)
+
+         # iterate decoder_layer
+         self_present_key_values = ()
+         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
+             self.layers, self_past_key_values, cross_past_key_values
+         ):
+             hidden_states, self_present_key_value = decoder_layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 encoder_attention_mask=encoder_attention_mask,
+                 self_past_key_value=self_past_key_value,
+                 cross_past_key_value=cross_past_key_value,
+                 cache_position=cache_position,
+             )
+             self_present_key_values += self_present_key_value
+
+         if self.final_layer_norm is not None:
+             hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states, self_present_key_values
+
+
+ class Seq2SeqDecoderLayer(torch.nn.Module):
+     """A modified decoder-only model implementation optimized for RBLN compilation.
+
+     Args:
+         model: Original Huggingface model to adapt
+         layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+         self_attn (Seq2SeqSelfAttention): Modified self-attention layer optimized for RBLN
+     """
+
+     def __init__(self, decoder_layer, self_attn):
+         super().__init__()
+         self._original_mod = decoder_layer
+         self.self_attn = self_attn
+         self.__post_init__()
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def pre_self_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'pre_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def post_self_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'post_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def pre_cross_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'pre_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def post_cross_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'post_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         self_past_key_value: Tuple[torch.Tensor],
+         cross_past_key_value: Tuple[torch.Tensor],
+         cache_position: torch.Tensor,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])
+
+         # Self Attention Block
+         residual = hidden_states
+         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
+         hidden_states, self_attn_past_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             past_key_value=self_past_key_value,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.post_self_attn_layer_norm(hidden_states)
+
+         # Cross-Attention Block
+         residual = hidden_states
+         hidden_states = self.pre_cross_attn_layer_norm(hidden_states)
+         cross_attn_output = self.encoder_attn(
+             hidden_states=hidden_states,
+             past_key_value=cross_past_key_value,
+             attention_mask=encoder_attention_mask,
+             key_value_states=dummy_encoder_hidden_states,
+         )
+         hidden_states = residual + cross_attn_output[0]
+         hidden_states = self.post_cross_attn_layer_norm(hidden_states)
+
+         # Feed-Forward Block
+         hidden_states = self.ff_layer(hidden_states)
+
+         return hidden_states, self_attn_past_key_value
+
+
+ class Seq2SeqSelfAttention(nn.Module):
+     def __init__(self, attn):
+         super().__init__()
+         self._original_mod = attn
+         self.__post_init__()
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+         return tensor.view(bsz, 1, seq_len, 1, self.num_heads, self.head_dim).transpose(2, 4)
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Projects input hidden states into query, key, and value representations.
+
+         Args:
+             hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+
+         Returns:
+             Tuple of (query_states, key_states, value_states)
+         """
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+         return query_states, key_states, value_states
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Tuple[torch.Tensor],
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         bsz, tgt_len, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+         query_states = self._shape(query_states, tgt_len, bsz)
+         key_states = self._shape(key_states, -1, bsz)
+         value_states = self._shape(value_states, -1, bsz)
+
+         all_key_states = []
+         all_value_states = []
+         all_attn_output = []
+         for b_idx in range(bsz):
+             query_state = query_states[b_idx]
+             key_state = key_states[b_idx]
+             value_state = value_states[b_idx]
+             attn_mask = attention_mask[b_idx].unsqueeze(0).unsqueeze(2)
+             past_key_state = past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim)
+             past_value_state = past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim)
+
+             attn_output, key_state, value_state = self.attn_decode(
+                 query_state,
+                 key_state,
+                 value_state,
+                 attn_mask,
+                 past_key_state,
+                 past_value_state,
+                 cache_position[b_idx][0],
+                 torch.tensor(1.0, dtype=torch.float32), # scale
+             )
+
+             attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim).transpose(1, 2)
+             attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+             all_key_states.append(key_state.squeeze(2))
+             all_value_states.append(value_state.squeeze(2))
+             all_attn_output.append(attn_output)
+
+         key_states = torch.cat(all_key_states, dim=0)
+         value_states = torch.cat(all_value_states, dim=0)
+         attn_output = torch.cat(all_attn_output, dim=0)
+
+         attn_output = self.out_proj(attn_output)
+         present_key_value = (key_states, value_states)
+
+         return attn_output, present_key_value
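
Note on the encoder wrapper above: step 2 of Seq2SeqEncoderWrapper.forward precomputes every decoder layer's cross-attention key/value tensors once, so the decoder never has to re-project the encoder output at generation time. The following is a minimal, pure-PyTorch sketch of that precompute pattern, using hypothetical sizes and stand-in nn.Linear projections; it omits the RBLN-specific step 3, where the stacked tensor is written into a preallocated device cache via torch.ops.rbln_custom_ops.rbln_cache_update.

import torch
from torch import nn

# Hypothetical sizes for illustration only.
num_layers, num_heads, d_model, enc_max_seq_len = 2, 4, 64, 16
d_kv = d_model // num_heads

# Stand-ins for each decoder layer's cross-attention k_proj / v_proj.
cross_k_projects = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))
cross_v_projects = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))

# Encoder output for a single sequence padded to enc_max_seq_len.
last_hidden_states = torch.randn(1, enc_max_seq_len, d_model)

cross_kv = []
for k_proj, v_proj in zip(cross_k_projects, cross_v_projects):
    # (1, seq, d_model) -> (1, num_heads, seq, d_kv), mirroring the wrapper's reshape.
    past_k = k_proj(last_hidden_states).view(1, enc_max_seq_len, num_heads, d_kv).transpose(1, 2)
    past_v = v_proj(last_hidden_states).view(1, enc_max_seq_len, num_heads, d_kv).transpose(1, 2)
    cross_kv.append(past_k)
    cross_kv.append(past_v)

# One stacked tensor holding K and V for every layer: (2 * num_layers, 1, num_heads, seq, d_kv).
cross_kv = torch.stack(cross_kv, dim=0)
print(cross_kv.shape)  # torch.Size([4, 1, 4, 16, 16])
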
optimum/rbln/transformers/models/t5/__init__.py
@@ -22,4 +22,3 @@
  # from Rebellions Inc.

  from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
- from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -34,9 +34,9 @@ from transformers import (
  )
  from transformers.modeling_outputs import BaseModelOutput

+ from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
  from ....modeling import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
- from ....modeling_diffusers import RBLNDiffusionMixin
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ...models.seq2seq import RBLNModelForSeq2SeqLM
@@ -192,7 +192,10 @@ class RBLNT5EncoderModel(RBLNModel):
  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         return T5Wrapper(model)
+         enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
+         dec_max_seq_len = rbln_config.model_cfg["dec_max_seq_len"]
+
+         return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)

      def __getattr__(self, __name: str) -> Any:
          def redirect(func):
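
With this change, RBLNT5ForConditionalGeneration passes the encoder and decoder maximum sequence lengths from rbln_config.model_cfg into T5Wrapper, which lines up with the Seq2SeqWrapper/Seq2SeqDecoderWrapper classes added in seq2seq_architecture.py above. That decoder wrapper receives its per-layer self-attention cache as one flat tuple of tensors and re-pairs it layer by layer; below is a minimal, self-contained sketch of that pairing with hypothetical shapes (not the actual RBLN runtime buffers).

import torch

# Hypothetical cache layout: one (key, value) pair of
# (batch, num_heads, dec_max_seq_len, head_dim) tensors per decoder layer.
num_layers, batch, num_heads, dec_max_seq_len, head_dim = 3, 1, 4, 8, 16
self_kv_cache = tuple(
    torch.zeros(batch, num_heads, dec_max_seq_len, head_dim) for _ in range(num_layers * 2)
)

# Mirrors the loop in Seq2SeqDecoderWrapper.forward: consecutive tensors in the
# flat tuple become one layer's (key, value) pair.
self_past_key_values = tuple(
    (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
assert len(self_past_key_values) == num_layers
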