optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +8 -1
  68. optimum/rbln/utils/logging.py +38 -1
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/METADATA +0 -106
  79. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  80. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import rebel
 import torch
+from rebel.compile_context import CompileContext
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
@@ -37,23 +38,16 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 from ....modeling import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from .generation_whisper import RBLNWhisperGenerationMixin
-from .whisper_architecture import (
-    _WhisperDecoderWrapper,
-    _WhisperEncoderWrapper,
-)
+from .whisper_architecture import WhisperWrapper
 
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, PretrainedConfig, PreTrainedModel
 
 
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -164,47 +158,51 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         # TODO(jongho): implement
         raise NotImplementedError
 
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
+        return WhisperWrapper(model, rbln_token_timestamps)
+
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
-        wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-        wrapped_decoder = _WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps).eval()
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
-        dec_rbln_compile_config = rbln_config.compile_cfgs[1]
+        enc_compile_config = rbln_config.compile_cfgs[0]
+        dec_compile_config = rbln_config.compile_cfgs[1]
 
-        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
-        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
+        context = CompileContext(use_weight_sharing=False)
 
-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+        enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
 
-        enc_ir = rebel.torchscript_to_ir(
-            enc_scripted_model,
-            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
-            name=enc_rbln_compile_config.mod_name,
+        # Mark encoder's static tensors (cross kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+            if "key_value_states" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
+
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+
+        # Mark decoder's static tensors (self kv states)
+        for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+            if "key_value_states" in name:
+                context.mark_static_address(tensor)
+
+        compiled_encoder = super().compile(
+            wrapped_model.encoder,
+            enc_compile_config,
+            example_inputs=enc_example_inputs,
+            compile_context=context,
         )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
-            name=dec_rbln_compile_config.mod_name,
+        compiled_decoder = super().compile(
+            wrapped_model.decoder,
+            dec_compile_config,
+            example_inputs=dec_example_inputs,
+            compile_context=context,
         )
 
-        # Caching encoder/decoder I/O
-        connections = [
-            (enc_ir.outputs[0], dec_ir.inputs[4]),
-            (dec_ir.outputs[1], dec_ir.inputs[3]),
-        ]
-        compiled_model = rebel.compile(
-            enc_ir,
-            dec_ir,
-            connections=connections,
-            fusion=enc_rbln_compile_config.fusion,
-            npu=enc_rbln_compile_config.npu,
-            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
-        )
-        return compiled_model
+        return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
     def _get_rbln_config(
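
The rewrite above replaces the old trace-to-IR flow (torch.jit.trace → rebel.torchscript_to_ir → rebel.compile with hand-wired `connections`) with two `super().compile()` calls under a shared `CompileContext`: every KV-cache buffer is pinned to a static device address, so the encoder's cross-KV output and the decoder's cache inputs alias the same memory instead of being connected by tensor index. A minimal sketch of the marking pattern, using only the APIs visible in this hunk; the helper name `mark_kv_static_addresses` is illustrative, and `enc_compile_config`/`dec_compile_config` are the objects from the diff:

    from rebel.compile_context import CompileContext

    def mark_kv_static_addresses(compile_config, example_inputs, context, shared=None):
        # Pin each KV-cache dummy tensor to a static device address; modules
        # compiled under the same context then share these buffers.
        shared = {} if shared is None else shared
        for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
            if "key_value_states" in name:
                shared[name] = tensor
                context.mark_static_address(tensor)
        return shared

    context = CompileContext(use_weight_sharing=False)
    enc_inputs = enc_compile_config.get_dummy_inputs(fill=0)
    statics = mark_kv_static_addresses(enc_compile_config, enc_inputs, context)
    # The decoder reuses the encoder's cross-KV tensors as its dummy inputs:
    dec_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=statics)
    mark_kv_static_addresses(dec_compile_config, dec_inputs, context)
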
@@ -228,6 +226,22 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
         # model input info
         enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
+        enc_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        model_config.decoder_layers * 2,
+                        rbln_batch_size,
+                        model_config.decoder_attention_heads,
+                        enc_max_seq_len,
+                        model_config.d_model // model_config.decoder_attention_heads,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
         dec_input_info = [
             ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
             ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
@@ -236,13 +250,13 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         dec_input_info.extend(
             [
                 (
-                    "self_key_value_states",
+                    "cross_key_value_states",
                     [
                         model_config.decoder_layers * 2,
                         rbln_batch_size,
                         model_config.decoder_attention_heads,
-                        rbln_dec_max_seq_len,
-                        model_config.d_model // model_config.encoder_attention_heads,
+                        enc_max_seq_len,
+                        model_config.d_model // model_config.decoder_attention_heads,
                     ],
                     "float32",
                 )
@@ -251,25 +265,25 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         dec_input_info.extend(
             [
                 (
-                    "cross_key_value_states",
+                    f"self_key_value_states_{i}",
                     [
-                        model_config.decoder_layers * 2,
                         rbln_batch_size,
                         model_config.decoder_attention_heads,
-                        enc_max_seq_len,
+                        rbln_dec_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
                     ],
                     "float32",
                 )
+                for i in range(model_config.decoder_layers * 2)
             ]
         )
 
-        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
-        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
+        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            compile_cfgs=[enc_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )
 
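
Taken together, the three input-info hunks above change the cache layout: the cross-attention cache becomes a single stacked tensor sized by `enc_max_seq_len` (now also declared as an encoder buffer, with its head dimension derived from `decoder_attention_heads`), while the self-attention cache is split into one tensor per key/value buffer. A worked shape example, assuming a whisper-tiny-sized config (decoder_layers=4, decoder_attention_heads=6, d_model=384) with batch size 1, enc_max_seq_len=1500, and rbln_dec_max_seq_len=448; all concrete values are illustrative:

    batch, layers, heads, d_model = 1, 4, 6, 384
    enc_max_seq_len, dec_max_seq_len = 1500, 448
    head_dim = d_model // heads  # 64

    # One stacked cross-KV tensor, shared by encoder and decoder:
    cross_kv_shape = [layers * 2, batch, heads, enc_max_seq_len, head_dim]
    assert cross_kv_shape == [8, 1, 6, 1500, 64]

    # Per-layer self-KV tensors, one entry per key/value buffer:
    self_kv_info = [
        (f"self_key_value_states_{i}", [batch, heads, dec_max_seq_len, head_dim])
        for i in range(layers * 2)  # 8 tensors of shape [1, 6, 448, 64]
    ]
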
@@ -285,12 +299,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
 
     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+            cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
         return [
-            compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(
+                tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+            ),
+            compiled_models[1].create_runtime(
+                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            ),
         ]
 
     def prepare_inputs_for_generation(
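
Runtime creation now expects one compiled model per submodule, keyed "encoder" and "decoder" in the device map, and threads through an optional profiler flag. A minimal sketch of the new contract, mirroring the hunk above rather than documenting a public API; the device assignments are illustrative and `compiled_models` is the list from the diff:

    rbln_device_map = {"encoder": 0, "decoder": 0}  # submodule -> NPU device id

    runtimes = [
        compiled.create_runtime(
            tensor_type="pt", device=rbln_device_map[name], activate_profiler=False
        )
        for name, compiled in zip(["encoder", "decoder"], compiled_models)
    ]
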