optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/auto/auto_factory.py

@@ -14,13 +14,14 @@
 import importlib
 import inspect
 import warnings
-from typing import Type
+from pathlib import Path
+from typing import Any, Dict, Optional, Type, Union
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.auto_factory import _get_model_class
 
-from optimum.rbln.configuration_utils import RBLNAutoConfig
+from optimum.rbln.configuration_utils import RBLNAutoConfig, RBLNModelConfig
 from optimum.rbln.modeling_base import RBLNBaseModel
 from optimum.rbln.utils.model_utils import (
     MODEL_MAPPING,
@@ -43,10 +44,10 @@ class _BaseAutoModelClass:
     @classmethod
     def get_rbln_cls(
         cls,
-        pretrained_model_name_or_path,
-        *args,
-        export=True,
-        **kwargs,
+        pretrained_model_name_or_path: Union[str, Path],
+        *args: Any,
+        export: bool = None,
+        **kwargs: Any,
     ):
         """
         Determine the appropriate RBLN model class based on the given model ID and configuration.
@@ -59,6 +60,20 @@ class _BaseAutoModelClass:
         Returns:
             RBLNBaseModel: The corresponding RBLN model class.
         """
+        if isinstance(pretrained_model_name_or_path, Path):
+            pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()
+
+        if export is None:
+            export = not RBLNBaseModel._is_compiled(
+                model_id=pretrained_model_name_or_path,
+                token=kwargs.get("token"),
+                revision=kwargs.get("revision"),
+                force_download=kwargs.get("force_download", False),
+                cache_dir=kwargs.get("cache_dir"),
+                subfolder=kwargs.get("subfolder", ""),
+                local_files_only=kwargs.get("local_files_only", False),
+            )
+
         if export:
             hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
             rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
@@ -85,9 +100,9 @@ class _BaseAutoModelClass:
     @classmethod
     def infer_hf_model_class(
         cls,
-        pretrained_model_name_or_path,
-        *args,
-        **kwargs,
+        pretrained_model_name_or_path: Union[str, Path],
+        *args: Any,
+        **kwargs: Any,
     ):
         """
         Infer the HuggingFace model class based on the configuration or model name.
@@ -140,7 +155,7 @@ class _BaseAutoModelClass:
         return model_class
 
     @classmethod
-    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
+    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
         """
         Retrieve the RBLN model class name for a given model ID.
 
@@ -163,17 +178,77 @@ class _BaseAutoModelClass:
         return rbln_config.rbln_model_cls_name
 
     @classmethod
-    def from_pretrained(cls, model_id, *args, **kwargs):
-        rbln_cls = cls.get_rbln_cls(model_id, *args, **kwargs)
-        return rbln_cls.from_pretrained(model_id, *args, **kwargs)
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        export: bool = None,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        **kwargs,
+    ):
+        """
+        Load an RBLN-accelerated model from a pretrained checkpoint or a compiled RBLN artifact.
+
+        This convenience method determines the concrete `RBLN*` model class that matches the
+        underlying HuggingFace architecture and dispatches to that class's
+        `from_pretrained()` implementation. Depending on whether a compiled RBLN folder is
+        detected (or if `export=True` is passed), it will either:
+
+        - Compile from a HuggingFace checkpoint to an RBLN model
+        - Or load an already-compiled RBLN model directory/repository
+
+        Args:
+            model_id:
+                HF repo id or local path. For compiled models, this should point to a directory
+                (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
+            export:
+                Force compilation from a HuggingFace checkpoint. When `None`, this is inferred by
+                checking whether compiled artifacts exist at `model_id`.
+            rbln_config:
+                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
+                instance of the specific model's config class (e.g., `RBLNLlamaForCausalLMConfig`).
+            kwargs: Additional keyword arguments.
+                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
+                - Remaining arguments are forwarded to the HuggingFace loader (e.g., `revision`,
+                  `token`, `trust_remote_code`, `cache_dir`, `subfolder`, `local_files_only`).
+
+        Returns:
+            An instantiated RBLN model ready for inference on RBLN NPUs.
+        """
+        rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
+        return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
 
     @classmethod
-    def from_model(cls, model, *args, **kwargs):
+    def from_model(
+        cls,
+        model: PreTrainedModel,
+        config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
+        **kwargs: Any,
+    ) -> RBLNBaseModel:
+        """
+        Convert and compile an in-memory HuggingFace model into an RBLN model.
+
+        This method resolves the appropriate concrete `RBLN*` class from the input model's class
+        name (e.g., `LlamaForCausalLM` -> `RBLNLlamaForCausalLM`) and then delegates to that
+        class's `from_model()` implementation.
+
+        Args:
+            model: A HuggingFace model instance to convert.
+            config: The configuration object associated with the model.
+            rbln_config:
+                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
+                instance of the specific model's config class.
+            kwargs: Additional keyword arguments.
+                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
+
+        Returns:
+            An instantiated RBLN model ready for inference on RBLN NPUs.
+        """
         rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
-        return rbln_cls.from_model(model, *args, **kwargs)
+        return rbln_cls.from_model(model, config=config, rbln_config=rbln_config, **kwargs)
 
     @staticmethod
-    def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
+    def register(rbln_cls: Type[RBLNBaseModel], exist_ok: bool = False):
         """
         Register a new RBLN model class.
 
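Taken together, the new export autodetection and the explicit rbln_config parameter let one call path cover both the compile-and-load and the load-precompiled flows. A minimal usage sketch (the checkpoint id, save path, and config values are illustrative, not taken from this diff):

from optimum.rbln import RBLNAutoModelForCausalLM

# First run: no *.rbln artifacts exist at the model id, so export is inferred
# as True and the HuggingFace checkpoint is compiled for RBLN NPUs.
model = RBLNAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # illustrative checkpoint id
    rbln_config={"tensor_parallel_size": 4},  # dict form; a config instance also works
)
model.save_pretrained("llama-2-7b-rbln")

# Second run: the directory now contains *.rbln files and rbln_config.json,
# so export is inferred as False and the compiled artifacts are loaded directly.
model = RBLNAutoModelForCausalLM.from_pretrained("llama-2-7b-rbln")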
optimum/rbln/transformers/models/auto/modeling_auto.py

@@ -35,8 +35,12 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_TEXT_ENCODING_MAPPING,
+    MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
     MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
@@ -53,65 +57,106 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
 
 
 class RBLNAutoModel(_BaseAutoModelClass):
+    """Automatically detect all supported transformers models."""
+
     _model_mapping = MODEL_MAPPING
     _model_mapping_names = MODEL_MAPPING_NAMES
 
 
 class RBLNAutoModelForCTC(_BaseAutoModelClass):
+    """Automatically detect Connectionist Temporal Classification (CTC) head Models."""
+
     _model_mapping = MODEL_FOR_CTC_MAPPING
     _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
 
 
 class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
+    """Automatically detect Causal Language Models."""
+
     _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
     _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+    """Automatically detect Sequence to Sequence Language Models."""
+
     _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
     _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+    """Automatically detect Speech Sequence to Sequence Models."""
+
     _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
     _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
+    """Automatically detect Depth Estimation Models."""
+
     _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
     _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
+    """Automatically detect Sequence Classification Models."""
+
     _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
     _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
+    """Automatically detect Vision to Sequence Generation Models."""
+
     _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
     _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
+    """Automatically detect Image and Text to Text Generation Models."""
+
     _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
     _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
 
 
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
+    """Automatically detect Masked Language Models."""
+
     _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
+    """Automatically detect Audio Classification Models."""
+
     _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
     _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
+    """Automatically detect Image Classification Models."""
+
     _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
     _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
+    """Automatically detect Question Answering Models."""
+
     _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
     _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+
+
+class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
+    """Automatically detect Text Encoding Models."""
+
+    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
+    _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
+
+
+class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+    """Automatically detect Zero Shot Object Detection Models."""
+
+    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
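The two new auto classes reuse the same dispatch machinery as the existing ones, so the Grounding DINO support added in this release can be reached without naming the concrete RBLN class. A sketch, assuming the class is re-exported at the package top level (the checkpoint id is illustrative):

from optimum.rbln import RBLNAutoModelForZeroShotObjectDetection

# Resolves GroundingDinoForObjectDetection -> RBLNGroundingDinoForObjectDetection
# via MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, then compiles or loads it.
model = RBLNAutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny",  # illustrative checkpoint id
)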
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -16,9 +16,7 @@ from typing import Tuple
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-)
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.utils import logging
 
 from ..seq2seq.seq2seq_architecture import (
optimum/rbln/transformers/models/bart/configuration_bart.py

@@ -32,3 +32,5 @@ class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized BART models for conditional text generation tasks.
     """
+
+    support_paged_attention = True
optimum/rbln/transformers/models/bert/bert_architecture.py (new file)

@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
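The wrapper normalizes BERT outputs for compilation: bare tensors pass through, and None entries (which graph capture cannot represent) are stripped from tuple outputs. A self-contained illustration, with a toy module standing in for the HuggingFace model:

import torch

class ToyEncoder(torch.nn.Module):
    def forward(self, x):
        # Mimics a HF model returning (hidden_states, None) when optional
        # outputs such as pooler_output are disabled.
        return (x * 2, None)

wrapped = BertModelWrapper(ToyEncoder(), rbln_config=None)
out = wrapped(torch.ones(1, 4))
print(len(out))                                      # 1 -> the None entry was dropped
print(torch.equal(out[0], torch.full((1, 4), 2.0)))  # True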
optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ....utils.logging import get_logger
+import torch
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +34,10 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py

@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
@@ -25,6 +29,16 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
     """
 
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
 
 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
     """
@@ -36,24 +50,34 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
         **kwargs,
    ):
         """
         Args:
-            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
-
-        Raises:
-            ValueError: If batch_size is not a positive integer.
+            num_query_tokens (Optional[int]): The number of query tokens passed through the Transformer.
+            image_text_hidden_size (Optional[int]): Dimensionality of the hidden state of the image-text fusion layer.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
         """
         super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
         self.num_query_tokens = num_query_tokens
         self.image_text_hidden_size = image_text_hidden_size
 
 
 class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2ForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 models for conditional generation tasks that involve both image and text inputs.
+    """
+
     submodules = ["vision_model", "qformer", "language_model"]
@@ -62,14 +86,15 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             vision_model (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            qformer (Optional[RBLNModelConfig]): Configuration for the RBLN-optimized BLIP-2 Q-Former model.
             language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -79,6 +104,12 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_model = self.init_submodule_config(RBLNBlip2VisionModelConfig, vision_model)
-        self.language_model = language_model
-        self.qformer = self.init_submodule_config(RBLNBlip2QFormerModelConfig, qformer)
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Blip2 vision model. It will be set to 1.")
+            logger.warning("Ignore batch_size for Blip2 qformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.qformer = self.initialize_submodule_config(submodule_config=qformer, batch_size=1, force_kwargs=True)
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
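Because the parent config now pins the vision and Q-Former submodules to batch_size=1, only the language model follows the requested batch size. A construction sketch under that assumption (the import path and per-submodule dict form are assumptions; values are illustrative):

from optimum.rbln import RBLNBlip2ForConditionalGenerationConfig

config = RBLNBlip2ForConditionalGenerationConfig(
    batch_size=2,  # triggers the two warnings shown above
    language_model={"tensor_parallel_size": 4},  # assumed per-submodule dict form
)
# vision_model and qformer are forced to batch_size=1 regardless of the parent value
print(config.vision_model.batch_size, config.qformer.batch_size)  # 1 1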
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py

@@ -30,38 +30,31 @@ from transformers.utils import logging
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    import rebel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class LoopProjector:
-    def __init__(self, language_projection) -> None:
-        self.language_projection = language_projection
+class LoopProjector(LoopProcessor):
+    def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__(model=language_projection)
 
-    def forward(self, *args, **kwargs):
-        query_output = args[0]
+    def _get_batch_size(self, query_output, **kwargs):
+        return query_output.shape[0]
 
-        batch_size = query_output.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.language_projection(query_output[i : i + 1]))
-
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
+    def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
+        query_output_item = query_output[index : index + 1]
+        return ([query_output_item], {})
 
-    def __repr__(self) -> str:
-        return repr(self.language_projection)
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = torch.cat(outputs, dim=0)
+        return output
 
 
 class RBLNBlip2VisionModel(RBLNModel):
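LoopProjector's per-sample loop now lives in LoopProcessor from optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (+79 lines, new in this release but not shown in this diff). Its real implementation is not visible here; the hook protocol implied by the overrides above suggests a base class roughly like the following (a speculative reconstruction, including the hypothetical _prepare_common_inputs hook):

class LoopProcessor:
    # Speculative sketch: runs a runtime (typically compiled with batch size 1)
    # once per sample and merges the results. Hook names come from the
    # LoopProjector overrides above; the bodies are guesses.
    def __init__(self, model):
        self.model = model

    def _get_batch_size(self, *args, **kwargs):
        raise NotImplementedError

    def _prepare_common_inputs(self, *args, **kwargs):
        # Hypothetical hook: inputs shared by every iteration.
        return None

    def _prepare_inputs_for_iteration(self, index, common_inputs, *args, **kwargs):
        # Returns (positional_args, keyword_args) for one sample.
        raise NotImplementedError

    def _process_outputs(self, outputs, **kwargs):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        batch_size = self._get_batch_size(*args, **kwargs)
        common_inputs = self._prepare_common_inputs(*args, **kwargs)
        outputs = []
        for index in range(batch_size):
            it_args, it_kwargs = self._prepare_inputs_for_iteration(index, common_inputs, *args, **kwargs)
            outputs.append(self.model(*it_args, **it_kwargs))
        return self._process_outputs(outputs, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def __repr__(self):
        return repr(self.model)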
@@ -72,6 +65,8 @@ class RBLNBlip2VisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings
 
@@ -100,8 +95,7 @@ class RBLNBlip2VisionModel(RBLNModel):
             (
                 "pixel_values",
                 [
-                    # support for vllm CB (prefill)
-                    1,
+                    rbln_config.batch_size,
                     model_config.num_channels,
                     model_config.image_size,
                     model_config.image_size,
@@ -116,7 +110,7 @@ class RBLNBlip2VisionModel(RBLNModel):
 
     def forward(
         self,
-        pixel_values,
+        pixel_values: torch.FloatTensor,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -151,6 +145,8 @@ class RBLNBlip2QFormerModel(RBLNModel):
     mechanisms for multimodal understanding tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
@@ -178,7 +174,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()
 
     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens
 
@@ -199,7 +200,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "query_embeds",
                 [
-                    1,
+                    rbln_config.batch_size,
                     rbln_config.num_query_tokens,
                     model_config.hidden_size,
                 ],
@@ -208,7 +209,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_hidden_states",
                 [
-                    1,
+                    rbln_config.batch_size,
                     # image_text_hidden_size + cls token
                     rbln_config.image_text_hidden_size + 1,
                     model_config.encoder_hidden_size,
@@ -218,7 +219,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_attention_mask",
                 # image_text_hidden_size + cls token
-                [1, rbln_config.image_text_hidden_size + 1],
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
                 "int64",
             ),
         ]
@@ -265,7 +266,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         )
 
 
-class RBLNBlip2ForConditionalGeneration(RBLNModel):
+class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -433,3 +434,66 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
         )
 
         return inputs_embeds
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        batch_size = pixel_values.shape[0]
+        image_embeds = self.vision_model(
+            pixel_values,
+            return_dict=True,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        ).last_hidden_state
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs.last_hidden_state
+
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        language_model_inputs = self.language_projection(query_output)
+
+        if inputs_embeds is None:
+            if input_ids is None:
+                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+        return outputs
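The new generate() override mirrors the upstream transformers Blip2ForConditionalGeneration flow (vision encoder -> Q-Former -> projector -> language model). A usage sketch, with model an already-loaded RBLNBlip2ForConditionalGeneration instance (the checkpoint id, image path, and prompt are illustrative):

from PIL import Image
from transformers import Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")  # illustrative
image = Image.open("demo.jpg")
prompt = "Question: what is shown in the image? Answer:"

inputs = processor(images=image, text=prompt, return_tensors="pt")
generated_ids = model.generate(
    pixel_values=inputs["pixel_values"],
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=30,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())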