optimum-rbln: 0.1.15 (py3-none-any wheel) → 0.2.0 (py3-none-any wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79):
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +0 -1
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +7 -0
  68. optimum/rbln/utils/logging.py +37 -0
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/METADATA +37 -26
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  79. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +0 -0
@@ -26,13 +26,14 @@ import logging
26
26
  from abc import ABC
27
27
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
28
28
 
29
- import rebel # noqa: F401
30
- import torch # noqa: F401
29
+ import rebel
30
+ import torch
31
+ from rebel.compile_context import CompileContext
31
32
  from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
32
33
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
33
34
 
34
35
  from ....modeling import RBLNModel
35
- from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
36
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
36
37
  from ....utils.runtime_utils import RBLNPytorchRuntime
37
38
 
38
39
 
@@ -66,7 +67,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
66
67
  class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
67
68
  """
68
69
  This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
69
- This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
70
+ This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
70
71
 
71
72
  A class to convert and run pre-trained transformers based Seq2SeqLM models on RBLN devices.
72
73
  It implements the methods to convert a pre-trained transformers Seq2SeqLM model into a RBLN transformer model by:
@@ -88,49 +89,42 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
88
89
  def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
89
90
  wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
90
91
 
91
- wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
92
- wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
92
+ enc_compile_config = rbln_config.compile_cfgs[0]
93
+ dec_compile_config = rbln_config.compile_cfgs[1]
93
94
 
94
- wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
95
- wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
95
+ context = CompileContext(use_weight_sharing=False)
96
96
 
97
- enc_rbln_compile_config = rbln_config.compile_cfgs[0]
98
- dec_rbln_compile_config = rbln_config.compile_cfgs[1]
97
+ enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
99
98
 
100
- enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
101
- dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
99
+ # Mark encoder's static tensors (cross kv states)
100
+ static_tensors = {}
101
+ for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
102
+ if "key_value_states" in name:
103
+ static_tensors[name] = tensor
104
+ context.mark_static_address(tensor)
102
105
 
103
- enc_example_inputs[3].fill_(0)
104
- dec_example_inputs[4].fill_(-1)
106
+ dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
105
107
 
106
- enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
107
- dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
108
+ # Mark decoder's static tensors (self kv states)
109
+ for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
110
+ if "key_value_states" in name:
111
+ context.mark_static_address(tensor)
108
112
 
109
- enc_ir = rebel.torchscript_to_ir(
110
- enc_scripted_model,
111
- input_names=[v[0] for v in enc_rbln_compile_config.input_info],
112
- name=enc_rbln_compile_config.mod_name,
113
+ compiled_encoder = super().compile(
114
+ wrapped_model.encoder,
115
+ enc_compile_config,
116
+ example_inputs=enc_example_inputs,
117
+ compile_context=context,
113
118
  )
114
- dec_ir = rebel.torchscript_to_ir(
115
- dec_scripted_model,
116
- input_names=[v[0] for v in dec_rbln_compile_config.input_info],
117
- name=dec_rbln_compile_config.mod_name,
118
- )
119
-
120
- connections = [
121
- (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
122
- (dec_ir.outputs[1], dec_ir.inputs[5]),
123
- ]
124
119
 
125
- compiled_model = rebel.compile(
126
- enc_ir,
127
- dec_ir,
128
- connections=connections,
129
- fusion=enc_rbln_compile_config.fusion,
130
- npu=enc_rbln_compile_config.npu,
131
- tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
120
+ compiled_decoder = super().compile(
121
+ wrapped_model.decoder,
122
+ dec_compile_config,
123
+ example_inputs=dec_example_inputs,
124
+ compile_context=context,
132
125
  )
133
- return compiled_model
126
+
127
+ return {"encoder": compiled_encoder, "decoder": compiled_decoder}
134
128
 
135
129
  @classmethod
136
130
  def _get_rbln_config(
@@ -204,7 +198,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
204
198
  ],
205
199
  "float32",
206
200
  ),
207
- ("batch_idx", [], "int32"),
201
+ ("batch_position", [], "int16"),
208
202
  ]
209
203
 
210
204
  dec_input_info = [
@@ -216,17 +210,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
216
210
  [rbln_batch_size, 1],
217
211
  "int32",
218
212
  ),
219
- ("batch_position", [], "int32"),
220
213
  ]
221
214
  dec_input_info.extend(
222
215
  [
223
216
  (
224
- "self_key_value_states",
217
+ "cross_key_value_states",
225
218
  [
226
219
  n_layer * 2,
227
220
  rbln_batch_size,
228
221
  n_head,
229
- rbln_dec_max_seq_len,
222
+ rbln_enc_max_seq_len,
230
223
  d_kv,
231
224
  ],
232
225
  "float32",
@@ -236,24 +229,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
236
229
  dec_input_info.extend(
237
230
  [
238
231
  (
239
- "cross_key_value_states",
232
+ f"self_key_value_states_{i}",
240
233
  [
241
- n_layer * 2,
242
234
  rbln_batch_size,
243
235
  n_head,
244
- rbln_enc_max_seq_len,
236
+ rbln_dec_max_seq_len,
245
237
  d_kv,
246
238
  ],
247
239
  "float32",
248
240
  )
241
+ for i in range(n_layer * 2)
249
242
  ]
250
243
  )
251
- enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
252
- dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
244
+ enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
245
+ dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
253
246
 
254
247
  rbln_config = RBLNConfig(
255
248
  rbln_cls=cls.__name__,
256
- compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
249
+ compile_cfgs=[enc_compile_config, dec_compile_config],
257
250
  rbln_kwargs=rbln_kwargs,
258
251
  )
259
252
 
@@ -270,12 +263,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
270
263
 
271
264
  @classmethod
272
265
  def _create_runtimes(
273
- cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
266
+ cls,
267
+ compiled_models: List[rebel.RBLNCompiledModel],
268
+ rbln_device_map: Dict[str, int],
269
+ activate_profiler: Optional[bool] = None,
274
270
  ) -> List[rebel.Runtime]:
275
- device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
271
+ if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
272
+ cls._raise_missing_compiled_file_error(["encoder", "decoder"])
273
+
276
274
  return [
277
- compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
278
- compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
275
+ compiled_models[0].create_runtime(
276
+ tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
277
+ ),
278
+ compiled_models[1].create_runtime(
279
+ tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
280
+ ),
279
281
  ]
280
282
 
281
283
  def can_generate(self):
@@ -340,9 +342,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
340
342
  attention_mask=dec_attention_mask,
341
343
  encoder_attention_mask=attention_mask,
342
344
  cache_position=cache_position,
343
- batch_position=torch.tensor(0, dtype=torch.int32),
344
345
  )
345
- lm_logits = decoder_output.logits[0]
346
+ lm_logits = decoder_output.logits
346
347
 
347
348
  return Seq2SeqLMOutput(logits=lm_logits)
348
349
 
@@ -381,15 +382,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
381
382
  )
382
383
 
383
384
  # 3. make sure that encoder returns `ModelOutput`
384
- model_input_name = model_input_name if model_input_name is not None else self.main_input_name
385
385
  encoder_kwargs["return_dict"] = True
386
386
  encoder_kwargs["output_hidden_states"] = False
387
387
  encoder_kwargs["output_attentions"] = False
388
388
 
389
389
  for b in range(batch_size):
390
- batch_idx = torch.tensor(b, dtype=torch.int32)
390
+ batch_position = torch.tensor(b, dtype=torch.int16)
391
391
  encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
392
392
  encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
393
- model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
393
+ model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_position=batch_position)
394
394
 
395
395
  return model_kwargs