optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
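The headline changes in 0.2.0 are structural: the RBLNModel base class moves from optimum/rbln/modeling_base.py into the new optimum/rbln/modeling.py, custom attention and KV-cache-update operators are collected under optimum/rbln/ops/, the generic Seq2Seq compilation path is factored into transformers/models/seq2seq/seq2seq_architecture.py, and the old generation helpers and hf_hub_cached modules are removed. As a rough orientation only, here is a minimal usage sketch; it assumes the optimum-style from_pretrained(..., export=True) entry point and the rbln_max_seq_len / rbln_batch_size keyword convention implied by the rbln_kwargs handling in the modeling_t5.py hunks below, and that RBLNT5EncoderModel is re-exported from the package root. None of these argument names are spelled out in this diff.

# Hypothetical usage sketch, not an excerpt from the release.
from transformers import AutoTokenizer

from optimum.rbln import RBLNT5EncoderModel  # assumed top-level re-export

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = RBLNT5EncoderModel.from_pretrained(
    "google-t5/t5-small",
    export=True,           # compile the PyTorch checkpoint for RBLN NPUs
    rbln_max_seq_len=256,  # assumed to surface as rbln_kwargs["max_seq_len"] in _get_rbln_config
    rbln_batch_size=1,     # assumed to surface as rbln_kwargs["batch_size"]
)

inputs = tokenizer(
    "Compiled once, run with static shapes.",
    padding="max_length",
    max_length=256,
    return_tensors="pt",
)
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)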
@@ -0,0 +1,498 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+ from transformers.utils import logging
+
+ from ....ops import register_rbln_custom_attention, register_rbln_custom_cache_update
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Seq2SeqWrapper:
+     """A wrapper class for Seq2Seq models to support RBLN-specific optimizations.
+
+     This wrapper divides the Seq2Seq model into separate encoder and decoder wrappers,
+     enabling specific optimizations such as custom cache handling and attention mechanisms.
+
+     Args:
+         model (nn.Module): The Seq2Seq model to wrap.
+         enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
+         **kwargs: Additional arguments to pass to the decoder wrapper.
+     """
+
+     def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
+         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+         self.decoder = Seq2SeqDecoderWrapper(model, **kwargs)
+
+
+ class Seq2SeqEncoderWrapper(nn.Module):
+     """A wrapper for the encoder component of a Seq2Seq model, designed for RBLN optimization.
+
+     This wrapper modifies the standard encoder-decoder architecture of Seq2Seq models to optimize
+     memory usage and attention mechanisms, particularly in cross-attention layers. It supports custom
+     cache handling to improve performance during decoding.
+
+     Args:
+         model (nn.Module): The Seq2Seq model containing the encoder.
+         enc_max_seq_len (int): Maximum sequence length for encoder embeddings and cache sizes.
+     """
+
+     def __init__(self, model: nn.Module, enc_max_seq_len: int):
+         super().__init__()
+         register_rbln_custom_cache_update()
+         self.config = model.config
+         self.encoder = model.get_encoder()
+         self.encoder_max_length = enc_max_seq_len
+         self.__post_init__(model)
+
+     def __post_init__(self, model: nn.Module):
+         """
+         Post-initialization to extract and configure encoder-related attributes.
+
+         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+         by subclasses to modify or add custom attributes as necessary.
+         """
+         self.n_layer = getattr(self.config, "decoder_layers", None)
+         self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+         self.num_heads = self.config.decoder_attention_heads
+         self.d_kv = self.config.d_model // self.num_heads
+
+     def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+         """
+         Extract cross-attention key and value projection layers from the decoder.
+         """
+         return (
+             nn.ModuleList(decoder_layers[i].encoder_attn.k_proj for i in range(self.n_layer)),
+             nn.ModuleList(decoder_layers[i].encoder_attn.v_proj for i in range(self.n_layer)),
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cross_key_values: torch.Tensor,
+         batch_position: torch.Tensor,
+     ) -> Tuple[torch.Tensor]:
+         # 1. get encoder last_hidden_states
+         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+         last_hidden_states = encoder_outputs[0]
+
+         # 2. pre-compute the cross-attention past_key_value, which is used in the decoding phase.
+         cross_kv = []
+         for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+             past_k = (
+                 k_proj(last_hidden_states).view(1, self.encoder_max_length, self.num_heads, self.d_kv).transpose(1, 2)
+             )
+             past_v = (
+                 v_proj(last_hidden_states).view(1, self.encoder_max_length, self.num_heads, self.d_kv).transpose(1, 2)
+             )
+
+             cross_kv.append(past_k)
+             cross_kv.append(past_v)
+
+         cross_kv = torch.stack(cross_kv, dim=0)
+
+         # 3. write the cross-attention past_key_value directly to device DRAM as an optimization.
+         batch_axis = torch.tensor(1, dtype=torch.int16)
+         cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
+             cross_key_values, cross_kv, batch_position, batch_axis
+         )
+
+         return cross_key_values
+
+
+ class Seq2SeqDecoderWrapper(nn.Module):
+     """
+     A wrapper for the decoder component of a Seq2Seq model, designed for RBLN optimization.
+
+     This wrapper handles tasks such as:
+     1. Converting decoder components to support RBLN-specific conditional generation.
+     2. Customizing attention mechanisms, including self-attention and cross-attention.
+     3. Managing the decoder's key-value caches for both self- and cross-attention.
+
+     Args:
+         model (nn.Module): The Seq2Seq model containing the decoder.
+         **kwargs: Additional arguments for decoder configuration.
+     """
+
+     def __init__(self, model: nn.Module, **kwargs):
+         super().__init__()
+         self.config = model.config
+         self.__post_init__(model, **kwargs)
+
+     def __post_init__(self, model: nn.Module, **kwargs):
+         """
+         Post-initialization to extract and configure decoder-related attributes.
+
+         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
+         by subclasses to modify or add custom attributes as necessary.
+         """
+         register_rbln_custom_attention()
+         self.num_layers = self.config.decoder_layers
+         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
+
+     def convert_to_rbln_conditional_generation(self, model: nn.Module):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = Seq2SeqSelfAttention(layer.self_attn)
+             new_layers.append(Seq2SeqDecoderLayer(layer, self_attn))
+
+         decoder_model = Seq2SeqDecoder(model.get_decoder(), new_layers)
+         new_model = Seq2SeqForConditionalGeneration(model, decoder_model)
+
+         return new_model
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         cross_kv_cache: torch.Tensor,
+         *self_kv_cache: torch.Tensor,
+     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+         self_past_key_values = ()
+         cross_past_key_values = ()
+         for i in range(0, self.num_layers * 2, 2):
+             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+         # decode
+         lm_logits, self_present_key_values = self.conditional_generation(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+         )
+
+         outputs = (lm_logits,) + self_present_key_values
+
+         return outputs
+
+
+ class Seq2SeqForConditionalGeneration(nn.Module):
+     """
+     A wrapper for Seq2Seq models supporting RBLN-specific optimizations for conditional generation.
+
+     This class adapts a Seq2Seq model for tasks like machine translation, summarization, or text generation
+     by:
+     1. Wrapping and customizing the decoder component to support key RBLN features.
+     2. Managing rescaling and output processing, if enabled.
+     3. Aligning model behavior with RBLN's static and efficient execution requirements.
+
+     Attributes:
+         has_rescaling (bool): Indicates if output rescaling is applied.
+         config (PretrainedConfig): Configuration from the original Seq2Seq model.
+         lm_head (nn.Linear): The language modeling head for output logits.
+         decoder (nn.Module): The wrapped decoder model.
+     """
+
+     has_rescaling = False
+
+     def __init__(self, model, decoder_model):
+         super().__init__()
+         self.config = model.config
+         self.lm_head = model.lm_head
+         self.decoder = decoder_model
+         self.__post_init__()
+
+     def __post_init__(self):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         encoder_attention_mask,
+         self_past_key_values,
+         cross_past_key_values,
+         cache_position,
+     ):
+         hidden_states, self_present_key_values = self.decoder(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             encoder_attention_mask=encoder_attention_mask,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+             cache_position=cache_position,
+         )
+
+         if self.has_rescaling and self.config.tie_word_embeddings:
+             hidden_states = hidden_states * self.scaling
+
+         lm_logits = self.lm_head(hidden_states)
+
+         return lm_logits, self_present_key_values
+
+
+ class Seq2SeqDecoder(torch.nn.Module):
+     """A modified Seq2SeqDecoder implementation optimized for RBLN compilation.
+
+     Args:
+         model: Original Huggingface model to adapt
+         layers (List[Seq2SeqDecoderLayer]): Modified transformer layers optimized for RBLN
+     """
+
+     has_pos_emb = True
+
+     def __init__(self, model, layers, **kwargs):
+         super().__init__()
+         self._original_mod = model
+         self.layers = nn.ModuleList(layers)
+         self.embed_tokens = model.embed_tokens
+         self.final_layer_norm = getattr(model, "final_layer_norm", None)
+         self.__post_init__(**kwargs)
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def get_embedding(self):
+         return self.embed_tokens
+
+     def prepare_attn_mask(self, *args, **kwargs):
+         raise NotImplementedError(
+             "The 'prepare_attn_mask' method is not implemented. Please define this method in a subclass."
+         )
+
+     def apply_position_embedding(self, *args, **kwargs):
+         raise NotImplementedError(
+             "The 'apply_position_embedding' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         self_past_key_values: torch.Tensor,
+         cross_past_key_values: torch.Tensor,
+         cache_position: torch.Tensor,
+     ):
+         # embedding
+         hidden_states = self.get_embedding()(input_ids)
+         attention_mask, encoder_attention_mask = self.prepare_attn_mask(
+             attention_mask, encoder_attention_mask, cache_position=cache_position
+         )
+
+         if self.has_pos_emb:
+             hidden_states = self.apply_position_embedding(hidden_states, cache_position)
+
+         # iterate decoder_layer
+         self_present_key_values = ()
+         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
+             self.layers, self_past_key_values, cross_past_key_values
+         ):
+             hidden_states, self_present_key_value = decoder_layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 encoder_attention_mask=encoder_attention_mask,
+                 self_past_key_value=self_past_key_value,
+                 cross_past_key_value=cross_past_key_value,
+                 cache_position=cache_position,
+             )
+             self_present_key_values += self_present_key_value
+
+         if self.final_layer_norm is not None:
+             hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states, self_present_key_values
+
+
+ class Seq2SeqDecoderLayer(torch.nn.Module):
+     """A modified Seq2Seq decoder layer optimized for RBLN compilation.
+
+     Args:
+         decoder_layer: Original Huggingface decoder layer to adapt
+         self_attn (Seq2SeqSelfAttention): Modified self-attention layer optimized for RBLN
+     """
+
+     def __init__(self, decoder_layer, self_attn):
+         super().__init__()
+         self._original_mod = decoder_layer
+         self.self_attn = self_attn
+         self.__post_init__()
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def pre_self_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'pre_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def post_self_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'post_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def pre_cross_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'pre_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def post_cross_attn_layer_norm(self, hidden_states):
+         raise NotImplementedError(
+             "The 'post_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         encoder_attention_mask: torch.Tensor,
+         self_past_key_value: Tuple[torch.Tensor],
+         cross_past_key_value: Tuple[torch.Tensor],
+         cache_position: torch.Tensor,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])
+
+         # Self Attention Block
+         residual = hidden_states
+         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
+         hidden_states, self_attn_past_key_value = self.self_attn(
+             hidden_states=hidden_states,
+             past_key_value=self_past_key_value,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.post_self_attn_layer_norm(hidden_states)
+
+         # Cross-Attention Block
+         residual = hidden_states
+         hidden_states = self.pre_cross_attn_layer_norm(hidden_states)
+         cross_attn_output = self.encoder_attn(
+             hidden_states=hidden_states,
+             past_key_value=cross_past_key_value,
+             attention_mask=encoder_attention_mask,
+             key_value_states=dummy_encoder_hidden_states,
+         )
+         hidden_states = residual + cross_attn_output[0]
+         hidden_states = self.post_cross_attn_layer_norm(hidden_states)
+
+         # Feed-Forward Block
+         hidden_states = self.ff_layer(hidden_states)
+
+         return hidden_states, self_attn_past_key_value
+
+
+ class Seq2SeqSelfAttention(nn.Module):
+     def __init__(self, attn):
+         super().__init__()
+         self._original_mod = attn
+         self.__post_init__()
+
+     def __post_init__(self, **kwargs):
+         """
+         Abstract method intended to be overridden by subclasses to modify or override
+         the attributes of the original model after initialization.
+         """
+         pass
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+         return tensor.view(bsz, 1, seq_len, 1, self.num_heads, self.head_dim).transpose(2, 4)
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Projects input hidden states into query, key, and value representations.
+
+         Args:
+             hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+
+         Returns:
+             Tuple of (query_states, key_states, value_states)
+         """
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+         return query_states, key_states, value_states
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Tuple[torch.Tensor],
+         attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+         bsz, tgt_len, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+         query_states = self._shape(query_states, tgt_len, bsz)
+         key_states = self._shape(key_states, -1, bsz)
+         value_states = self._shape(value_states, -1, bsz)
+
+         all_key_states = []
+         all_value_states = []
+         all_attn_output = []
+         for b_idx in range(bsz):
+             query_state = query_states[b_idx]
+             key_state = key_states[b_idx]
+             value_state = value_states[b_idx]
+             attn_mask = attention_mask[b_idx].unsqueeze(0).unsqueeze(2)
+             past_key_state = past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim)
+             past_value_state = past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim)
+
+             attn_output, key_state, value_state = self.attn_decode(
+                 query_state,
+                 key_state,
+                 value_state,
+                 attn_mask,
+                 past_key_state,
+                 past_value_state,
+                 cache_position[b_idx][0],
+                 torch.tensor(1.0, dtype=torch.float32),  # scale
+             )
+
+             attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim).transpose(1, 2)
+             attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+
+             all_key_states.append(key_state.squeeze(2))
+             all_value_states.append(value_state.squeeze(2))
+             all_attn_output.append(attn_output)
+
+         key_states = torch.cat(all_key_states, dim=0)
+         value_states = torch.cat(all_value_states, dim=0)
+         attn_output = torch.cat(all_attn_output, dim=0)
+
+         attn_output = self.out_proj(attn_output)
+         present_key_value = (key_states, value_states)
+
+         return attn_output, present_key_value
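The wrappers in this new file are intentionally generic: prepare_attn_mask, apply_position_embedding, the pre/post layer-norm hooks, ff_layer, and the attention submodules either raise NotImplementedError or are expected to be attached by a model-specific subclass (the reworked bart_architecture.py and t5_architecture.py in this release play that role). As an illustration only, here is a minimal sketch of what such a subclass of Seq2SeqDecoderLayer could look like for a BART-style post-norm layer; the _original_mod attribute names (encoder_attn, self_attn_layer_norm, fc1/fc2, activation_fn, final_layer_norm) are assumptions borrowed from the Hugging Face BART layer, not taken from this diff.

# Hypothetical subclass sketch; assumes HF BART-style submodule names on _original_mod.
from optimum.rbln.transformers.models.seq2seq.seq2seq_architecture import Seq2SeqDecoderLayer


class BartLikeDecoderLayer(Seq2SeqDecoderLayer):
    def __post_init__(self, **kwargs):
        # Expose the submodules the generic forward() expects.
        layer = self._original_mod
        self.encoder_attn = layer.encoder_attn  # cross-attention module
        self.self_attn_layer_norm = layer.self_attn_layer_norm
        self.encoder_attn_layer_norm = layer.encoder_attn_layer_norm
        self.fc1, self.fc2 = layer.fc1, layer.fc2
        self.activation_fn = layer.activation_fn
        self.final_layer_norm = layer.final_layer_norm

    # BART is post-norm: normalization happens after each residual add,
    # so the "pre" hooks are identity and the "post" hooks apply LayerNorm.
    def pre_self_attn_layer_norm(self, hidden_states):
        return hidden_states

    def post_self_attn_layer_norm(self, hidden_states):
        return self.self_attn_layer_norm(hidden_states)

    def pre_cross_attn_layer_norm(self, hidden_states):
        return hidden_states

    def post_cross_attn_layer_norm(self, hidden_states):
        return self.encoder_attn_layer_norm(hidden_states)

    def ff_layer(self, hidden_states):
        # Feed-forward block with residual connection and final LayerNorm.
        residual = hidden_states
        hidden_states = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        return self.final_layer_norm(residual + hidden_states)

A pre-norm architecture would instead normalize in the "pre" hooks and leave the "post" hooks as identity.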
@@ -22,4 +22,3 @@
  # from Rebellions Inc.

  from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
- from .t5_architecture import T5DecoderWrapper, T5EncoderWrapper
@@ -22,17 +22,23 @@
  # from Rebellions Inc.

  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

+ import torch
+ import transformers
  from transformers import (
      AutoModelForTextEncoding,
      PretrainedConfig,
+     T5EncoderModel,
      T5ForConditionalGeneration,
  )
+ from transformers.modeling_outputs import BaseModelOutput

- from ....modeling_base import RBLNModel
+ from ....diffusers.modeling_diffusers import RBLNDiffusionMixin
+ from ....modeling import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ....utils.logging import get_logger
+ from ....utils.runtime_utils import RBLNPytorchRuntime
  from ...models.seq2seq import RBLNModelForSeq2SeqLM
  from .t5_architecture import T5Wrapper

@@ -43,8 +49,60 @@ if TYPE_CHECKING:
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel


+ class RBLNRuntimeModel(RBLNPytorchRuntime):
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: torch.FloatTensor,
+         head_mask: torch.FloatTensor,
+         inputs_embeds: torch.FloatTensor,
+         **kwargs,
+     ):
+         return super().forward(
+             input_ids,
+             attention_mask,
+             head_mask,
+             inputs_embeds,
+             **kwargs,
+         )
+
+
+ class T5EncoderWrapper(torch.nn.Module):
+     def __init__(self, model: "T5EncoderModel") -> None:
+         super().__init__()
+         self.model = model
+
+     def forward(self, *args, **kwargs):
+         kwargs.pop("return_dict", None)
+         return self.model(*args, **kwargs, return_dict=False)
+
+
  class RBLNT5EncoderModel(RBLNModel):
      auto_model_class = AutoModelForTextEncoding
+     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+     def __post_init__(self, **kwargs):
+         self.model = RBLNRuntimeModel(runtime=self.model[0])
+
+     @classmethod
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+         return T5EncoderWrapper(model)
+
+     @classmethod
+     def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+         batch_size = rbln_config.get("batch_size", 1)
+         max_sequence_length = rbln_config.get("max_sequence_length", 256)
+         model_input_names = ["input_ids"]
+
+         rbln_config.update(
+             {
+                 "batch_size": batch_size,
+                 "max_seq_len": max_sequence_length,
+                 "model_input_names": model_input_names,
+             }
+         )
+
+         return rbln_config

      @classmethod
      def _get_rbln_config(
@@ -54,6 +112,7 @@ class RBLNT5EncoderModel(RBLNModel):
          rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
          rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
          rbln_batch_size = rbln_kwargs.get("batch_size", None)

          max_position_embeddings = getattr(model_config, "n_positions", None)
@@ -71,12 +130,27 @@ class RBLNT5EncoderModel(RBLNModel):
          if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
              raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

+         if rbln_model_input_names is None:
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_input_names"):
+                     rbln_model_input_names = tokenizer.model_input_names
+                     break
+             if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                 rbln_model_input_names = cls.rbln_model_input_names
+             elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                 original_model_class = getattr(transformers, model_config.architectures[0])
+                 input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                 raise ValueError(
+                     "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                     f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(input_names_order)})"
+                 )
+
          if rbln_batch_size is None:
              rbln_batch_size = 1

          input_info = [
-             ("input_ids", [rbln_batch_size, rbln_max_seq_len], "int64"),
-             ("attention_mask", [rbln_batch_size, rbln_max_seq_len], "int64"),
+             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+             for model_input_name in rbln_model_input_names
          ]

          rbln_compile_config = RBLNCompileConfig(input_info=input_info)
@@ -90,11 +164,38 @@ class RBLNT5EncoderModel(RBLNModel):
          rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
          return rbln_config

+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+         encoder_outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         if not return_dict:
+             return (encoder_outputs,)
+         else:
+             return BaseModelOutput(last_hidden_state=encoder_outputs)
+

  class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
      @classmethod
      def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         return T5Wrapper(model)
+         enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
+         dec_max_seq_len = rbln_config.model_cfg["dec_max_seq_len"]
+
+         return T5Wrapper(model, enc_max_seq_len=enc_max_seq_len, dec_max_seq_len=dec_max_seq_len)

      def __getattr__(self, __name: str) -> Any:
          def redirect(func):