optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197) hide show
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -23,9 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
23
23
  import rebel
24
24
  import torch
25
25
  from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
26
+ from transformers.utils.hub import PushToHubMixin
26
27
 
27
28
  from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
28
- from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
29
+ from .utils.hub import pull_compiled_model_from_hub, validate_files
29
30
  from .utils.logging import get_logger
30
31
  from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
31
32
  from .utils.save_utils import maybe_load_preprocessors
@@ -33,7 +34,7 @@ from .utils.submodule import SubModulesMixin
33
34
 
34
35
 
35
36
  if TYPE_CHECKING:
36
- from transformers import PreTrainedModel
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
37
38
 
38
39
  logger = get_logger(__name__)
39
40
 
@@ -50,11 +51,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
50
51
  model_type = "rbln_model"
51
52
  auto_model_class = AutoModel
52
53
  config_class = AutoConfig
53
-
54
54
  config_name = "config.json"
55
55
  hf_library_name = "transformers"
56
- _hf_class = None
57
- _rbln_config_class = None
56
+ _supports_non_fp32 = False
58
57
 
59
58
  def __init__(
60
59
  self,
@@ -72,7 +71,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
72
71
  self.rbln_config = rbln_config
73
72
  if not rbln_config.is_frozen():
74
73
  raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
75
-
76
74
  self.compiled_models = rbln_compiled_models
77
75
 
78
76
  # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
@@ -93,7 +91,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
93
91
 
94
92
  self.device = torch.device("cpu")
95
93
  self.training = False
96
- self.dtype = torch.float32
94
+ self.dtype = rbln_config.torch_dtype
97
95
 
98
96
  # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
99
97
  # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -115,7 +113,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
115
113
  def _load_compiled_model_dir(
116
114
  cls,
117
115
  model_id: Union[str, Path],
118
- use_auth_token: Optional[Union[bool, str]] = None,
116
+ token: Optional[Union[bool, str]] = None,
119
117
  revision: Optional[str] = None,
120
118
  force_download: bool = False,
121
119
  cache_dir: Optional[str] = None,
@@ -134,7 +132,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
134
132
  model_path = pull_compiled_model_from_hub(
135
133
  model_id=model_id,
136
134
  subfolder=subfolder,
137
- use_auth_token=use_auth_token,
135
+ token=token,
138
136
  revision=revision,
139
137
  cache_dir=cache_dir,
140
138
  force_download=force_download,
@@ -172,7 +170,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
172
170
  cls,
173
171
  model_id: Union[str, Path],
174
172
  config: Optional["PretrainedConfig"] = None,
175
- use_auth_token: Optional[Union[bool, str]] = None,
173
+ token: Optional[Union[bool, str]] = None,
176
174
  revision: Optional[str] = None,
177
175
  force_download: bool = False,
178
176
  cache_dir: Optional[str] = None,
@@ -189,7 +187,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
189
187
  if rbln_compiled_models is None:
190
188
  model_path_subfolder = cls._load_compiled_model_dir(
191
189
  model_id=model_id,
192
- use_auth_token=use_auth_token,
190
+ token=token,
193
191
  revision=revision,
194
192
  force_download=force_download,
195
193
  cache_dir=cache_dir,
@@ -232,7 +230,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
232
230
  cache_dir=cache_dir,
233
231
  force_download=force_download,
234
232
  revision=revision,
235
- token=use_auth_token,
233
+ token=token,
236
234
  trust_remote_code=trust_remote_code,
237
235
  )
238
236
  elif cls.hf_library_name == "diffusers":
@@ -250,7 +248,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
250
248
  force_download=force_download,
251
249
  local_files_only=local_files_only,
252
250
  revision=revision,
253
- token=use_auth_token,
251
+ token=token,
254
252
  subfolder=subfolder,
255
253
  )
256
254
  config = PretrainedConfig(**config)
@@ -316,7 +314,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
316
314
  rbln_config,
317
315
  model_save_dir=model_save_dir,
318
316
  subfolder=subfolder,
319
- rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
317
+ rbln_compiled_models=rbln_compiled_models,
320
318
  rbln_submodules=rbln_submodules,
321
319
  **kwargs,
322
320
  )
@@ -344,32 +342,72 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
344
342
  rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
345
343
  return rbln_config, kwargs
346
344
 
345
+ @classmethod
346
+ def _is_compiled(
347
+ cls,
348
+ model_id: Union[str, Path],
349
+ token: Optional[Union[bool, str]] = None,
350
+ revision: Optional[str] = None,
351
+ force_download: bool = False,
352
+ cache_dir: Optional[str] = None,
353
+ subfolder: str = "",
354
+ local_files_only: bool = False,
355
+ ) -> bool:
356
+ # Check if the model is already compiled.
357
+ try:
358
+ cls._load_compiled_model_dir(
359
+ model_id=model_id,
360
+ token=token,
361
+ revision=revision,
362
+ force_download=force_download,
363
+ cache_dir=cache_dir,
364
+ subfolder=subfolder,
365
+ local_files_only=local_files_only,
366
+ )
367
+ return True
368
+ except (FileNotFoundError, KeyError):
369
+ return False
370
+
347
371
  @classmethod
348
372
  def from_pretrained(
349
373
  cls: Type["RBLNBaseModel"],
350
374
  model_id: Union[str, Path],
351
- export: bool = False,
375
+ export: Optional[bool] = None,
352
376
  rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
353
- **kwargs: Dict[str, Any],
377
+ **kwargs: Any,
354
378
  ) -> "RBLNBaseModel":
355
379
  """
356
380
  The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
357
381
  User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
358
382
 
359
383
  Args:
360
- model_id: The model id of the pre-trained model to be loaded. It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
361
- export: A boolean flag to indicate whether the model should be compiled.
362
- rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
384
+ model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
385
+ It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
386
+ export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
387
+ If None, it will be determined based on the existence of the compiled model files in the model_id.
388
+ rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
389
+ This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
363
390
  For detailed configuration options, see the specific model's configuration class documentation.
364
-
365
- kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
391
+ kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
366
392
 
367
393
  Returns:
368
- A RBLN model instance ready for inference on RBLN NPU devices.
394
+ (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
369
395
  """
370
396
 
371
397
  if isinstance(model_id, Path):
372
398
  model_id = model_id.as_posix()
399
+
400
+ if export is None:
401
+ export = not cls._is_compiled(
402
+ model_id=model_id,
403
+ token=kwargs.get("token"),
404
+ revision=kwargs.get("revision"),
405
+ force_download=kwargs.get("force_download", False),
406
+ cache_dir=kwargs.get("cache_dir"),
407
+ subfolder=kwargs.get("subfolder", ""),
408
+ local_files_only=kwargs.get("local_files_only", False),
409
+ )
410
+
373
411
  from_pretrained_method = cls._export if export else cls._from_pretrained
374
412
  return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)
375
413
 
@@ -394,7 +432,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
394
432
  compiled_model = rebel.compile_from_torch(
395
433
  model,
396
434
  input_info=rbln_compile_config.input_info,
397
- fusion=rbln_compile_config.fusion,
398
435
  npu=rbln_compile_config.npu,
399
436
  tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
400
437
  **kwargs,
@@ -402,8 +439,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
402
439
  return compiled_model
403
440
 
404
441
  @classmethod
405
- def update_rbln_config(cls, **others) -> RBLNModelConfig:
406
- rbln_config = cls._update_rbln_config(**others)
442
+ def update_rbln_config(
443
+ cls,
444
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
445
+ model: "PreTrainedModel",
446
+ model_config: "PretrainedConfig",
447
+ rbln_config: RBLNModelConfig,
448
+ ) -> RBLNModelConfig:
449
+ rbln_config.torch_dtype = model.dtype
450
+ if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
451
+ raise NotImplementedError(
452
+ f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
453
+ )
454
+ rbln_config = cls._update_rbln_config(
455
+ preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
456
+ )
407
457
  rbln_config.freeze()
408
458
  if rbln_config.rbln_model_cls_name != cls.__name__:
409
459
  raise NameError(
@@ -421,7 +471,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
421
471
 
422
472
  # Returns:
423
473
  # type: The original HuggingFace model class
424
- if cls._hf_class is None:
474
+ if "_hf_class" not in cls.__dict__ or cls._hf_class is None:
425
475
  hf_cls_name = cls.__name__[4:]
426
476
  library = importlib.import_module(cls.hf_library_name)
427
477
  cls._hf_class = getattr(library, hf_cls_name, None)
@@ -430,7 +480,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
430
480
  @classmethod
431
481
  def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
432
482
  # Lazily loads and caches the corresponding RBLN model config class.
433
- if cls._rbln_config_class is None:
483
+ if "_rbln_config_class" not in cls.__dict__ or cls._rbln_config_class is None:
434
484
  rbln_config_class_name = cls.__name__ + "Config"
435
485
  cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
436
486
  return cls._rbln_config_class
@@ -446,12 +496,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
446
496
 
447
497
  # This method mimics the interface of torch.nn.Module.parameters()
448
498
  # specifically for code that uses `next(model.parameters())` to infer
449
- # the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
499
+ # the device or dtype. It yields a single dummy tensor on CPU with model dtype.
450
500
 
451
501
  # Warning:
452
502
  # This does NOT yield the actual model parameters used by the RBLN runtime.
453
503
  # Code relying on iterating through all model parameters will not work as expected.
454
- yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
504
+ yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))
455
505
 
456
506
  def __call__(self, *args, **kwargs):
457
507
  return self.forward(*args, **kwargs)
@@ -486,9 +536,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
486
536
  [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.
487
537
 
488
538
  Args:
489
- save_directory (`Union[str, Path]`):
539
+ save_directory (Union[str, Path]):
490
540
  Directory where to save the model file.
491
- push_to_hub (`bool`, *optional*, defaults to `False`):
541
+ push_to_hub (bool):
492
542
  Whether or not to push your model to the HuggingFace model hub after saving it.
493
543
 
494
544
  """
@@ -507,6 +557,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
507
557
  f"Please ensure the model directory exists and you have the necessary permissions to access it."
508
558
  )
509
559
 
560
+ if isinstance(self.config, PretrainedConfig):
561
+ self.config.save_pretrained(real_save_dir)
562
+
510
563
  if save_directory_path == real_save_dir:
511
564
  raise FileExistsError(
512
565
  f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
@@ -522,10 +575,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
522
575
  # First copy everything to a temporary directory
523
576
  shutil.copytree(real_save_dir, tmp_dir)
524
577
 
525
- # If everything succeeded, atomically replace the target directory
578
+ # If everything succeeded, move files to target directory
526
579
  if os.path.exists(save_directory_path):
527
- shutil.rmtree(save_directory_path)
528
- os.rename(tmp_dir, save_directory_path)
580
+ # Merge files from tmp_dir into existing directory
581
+ def _merge_dir(src_root: str, dst_root: str):
582
+ for name in os.listdir(src_root):
583
+ src_item = os.path.join(src_root, name)
584
+ dst_item = os.path.join(dst_root, name)
585
+
586
+ if os.path.islink(src_item) or os.path.isfile(src_item):
587
+ os.makedirs(os.path.dirname(dst_item), exist_ok=True)
588
+ if os.path.isdir(dst_item) and not os.path.islink(dst_item):
589
+ shutil.rmtree(dst_item)
590
+ os.replace(src_item, dst_item)
591
+ elif os.path.isdir(src_item):
592
+ if os.path.islink(dst_item) or os.path.isfile(dst_item):
593
+ os.remove(dst_item)
594
+ os.makedirs(dst_item, exist_ok=True)
595
+ _merge_dir(src_item, dst_item)
596
+ else:
597
+ # Fallback for special file types
598
+ os.replace(src_item, dst_item)
599
+
600
+ _merge_dir(tmp_dir, str(save_directory_path))
601
+
602
+ # Remove the temporary directory tree after merge
603
+ shutil.rmtree(tmp_dir)
604
+ else:
605
+ # If target doesn't exist, just rename tmp_dir to target
606
+ os.rename(tmp_dir, save_directory_path)
529
607
 
530
608
  except Exception as e:
531
609
  # Clean up the temporary directory if anything fails
@@ -534,7 +612,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
534
612
  raise e # Re-raise the exception after cleanup
535
613
 
536
614
  if push_to_hub:
537
- return super().push_to_hub(str(save_directory_path), **kwargs)
615
+ repo_id = kwargs.pop("repo_id", None)
616
+ if repo_id is None:
617
+ raise ValueError("`repo_id` must be provided to push the model to the HuggingFace model hub.")
618
+ return super().push_to_hub(repo_id=repo_id, **kwargs)
538
619
 
539
620
  @staticmethod
540
621
  def _raise_missing_compiled_file_error(missing_files: List[str]):
optimum/rbln/ops/attn.py CHANGED
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
53
53
  return torch.empty_like(q)
54
54
 
55
55
 
56
+ @torch.library.custom_op(
57
+ "rbln_custom_ops::paged_attn_decode_kv_fp8",
58
+ mutates_args=(["kcache", "vcache"]),
59
+ )
60
+ def paged_attn_decode_kv_fp8(
61
+ q: Tensor,
62
+ k: Tensor,
63
+ v: Tensor,
64
+ mask: Tensor,
65
+ kcache: Tensor,
66
+ vcache: Tensor,
67
+ seq: Tensor,
68
+ scale: Tensor,
69
+ block_table: Tensor,
70
+ block_size: int,
71
+ k_scale: Tensor,
72
+ v_scale: Tensor,
73
+ ) -> Tensor:
74
+ return torch.empty_like(q)
75
+
76
+
77
+ @paged_attn_decode_kv_fp8.register_fake
78
+ def paged_attn_decode_kv_fp8_fake(
79
+ q: Tensor,
80
+ k: Tensor,
81
+ v: Tensor,
82
+ mask: Tensor,
83
+ kcache: Tensor,
84
+ vcache: Tensor,
85
+ seq: Tensor,
86
+ scale: Tensor,
87
+ block_table: Tensor,
88
+ block_size: int,
89
+ k_scale: Tensor,
90
+ v_scale: Tensor,
91
+ ) -> Tensor:
92
+ return torch.empty_like(q)
93
+
94
+
56
95
  @torch.library.custom_op(
57
96
  "rbln_custom_ops::paged_attn_prefill",
58
97
  mutates_args=(["kcache", "vcache"]),
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
112
151
  return torch.empty_like(q)
113
152
 
114
153
 
154
+ @torch.library.custom_op(
155
+ "rbln_custom_ops::paged_attn_prefill_kv_fp8",
156
+ mutates_args=(["kcache", "vcache"]),
157
+ )
158
+ def paged_attn_prefill_kv_fp8(
159
+ q: Tensor,
160
+ k: Tensor,
161
+ v: Tensor,
162
+ mask: Tensor,
163
+ kcache: Tensor,
164
+ vcache: Tensor,
165
+ seq: Tensor,
166
+ scale: Tensor,
167
+ block_table: Tensor,
168
+ block_size: int,
169
+ k_scale: Tensor,
170
+ v_scale: Tensor,
171
+ ) -> Tensor:
172
+ return torch.empty_like(q)
173
+
174
+
175
+ @paged_attn_prefill_kv_fp8.register_fake
176
+ def paged_attn_prefill_kv_fp8_fake(
177
+ q: Tensor,
178
+ k: Tensor,
179
+ v: Tensor,
180
+ mask: Tensor,
181
+ kcache: Tensor,
182
+ vcache: Tensor,
183
+ seq: Tensor,
184
+ scale: Tensor,
185
+ block_table: Tensor,
186
+ block_size: int,
187
+ k_scale: Tensor,
188
+ v_scale: Tensor,
189
+ ) -> Tensor:
190
+ return torch.empty_like(q)
191
+
192
+
115
193
  @torch.library.custom_op(
116
194
  "rbln_custom_ops::paged_causal_attn_decode",
117
195
  mutates_args=(["kcache", "vcache"]),
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
236
314
  return torch.empty_like(q)
237
315
 
238
316
 
317
+ @torch.library.custom_op(
318
+ "rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
319
+ mutates_args=(["kcache", "vcache"]),
320
+ )
321
+ def paged_causal_attn_decode_kv_fp8(
322
+ q: Tensor,
323
+ k: Tensor,
324
+ v: Tensor,
325
+ kcache: Tensor,
326
+ vcache: Tensor,
327
+ seq: Tensor,
328
+ scale: Tensor,
329
+ block_table: Tensor,
330
+ block_size: int,
331
+ k_scale: Tensor,
332
+ v_scale: Tensor,
333
+ mask: Optional[Tensor] = None,
334
+ ) -> Tensor:
335
+ return torch.empty_like(q)
336
+
337
+
338
+ @paged_causal_attn_decode_kv_fp8.register_fake
339
+ def paged_causal_attn_decode_kv_fp8_fake(
340
+ q: Tensor,
341
+ k: Tensor,
342
+ v: Tensor,
343
+ kcache: Tensor,
344
+ vcache: Tensor,
345
+ seq: Tensor,
346
+ scale: Tensor,
347
+ block_table: Tensor,
348
+ block_size: int,
349
+ k_scale: Tensor,
350
+ v_scale: Tensor,
351
+ mask: Optional[Tensor] = None,
352
+ ) -> Tensor:
353
+ return torch.empty_like(q)
354
+
355
+
356
+ @torch.library.custom_op(
357
+ "rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
358
+ mutates_args=(["kcache", "vcache"]),
359
+ )
360
+ def paged_causal_attn_prefill_kv_fp8(
361
+ q: Tensor,
362
+ k: Tensor,
363
+ v: Tensor,
364
+ kcache: Tensor,
365
+ vcache: Tensor,
366
+ seq: Tensor,
367
+ scale: Tensor,
368
+ block_table: Tensor,
369
+ block_size: int,
370
+ is_bidirectional: bool,
371
+ k_scale: Tensor,
372
+ v_scale: Tensor,
373
+ mask: Optional[Tensor] = None,
374
+ ) -> Tensor:
375
+ return torch.empty_like(q)
376
+
377
+
378
+ @paged_causal_attn_prefill_kv_fp8.register_fake
379
+ def paged_causal_attn_prefill_kv_fp8_fake(
380
+ q: Tensor,
381
+ k: Tensor,
382
+ v: Tensor,
383
+ kcache: Tensor,
384
+ vcache: Tensor,
385
+ seq: Tensor,
386
+ scale: Tensor,
387
+ block_table: Tensor,
388
+ block_size: int,
389
+ is_bidirectional: bool,
390
+ k_scale: Tensor,
391
+ v_scale: Tensor,
392
+ mask: Optional[Tensor] = None,
393
+ ) -> Tensor:
394
+ return torch.empty_like(q)
395
+
396
+
239
397
  @torch.library.custom_op(
240
398
  "rbln_custom_ops::paged_add_softmax_attn_decode",
241
399
  mutates_args=(["kcache", "vcache"]),
@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
59
59
  return torch.empty_like(q)
60
60
 
61
61
 
62
+ @torch.library.custom_op(
63
+ "rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
64
+ mutates_args=(["kcache", "vcache"]),
65
+ )
66
+ def paged_flash_attn_decode_kv_fp8(
67
+ q: Tensor,
68
+ k: Tensor,
69
+ v: Tensor,
70
+ mask: Tensor,
71
+ kcache: Tensor,
72
+ vcache: Tensor,
73
+ seq: Tensor,
74
+ scale: Tensor,
75
+ block_table: Tensor,
76
+ block_size: int,
77
+ partition: int,
78
+ k_scale: Tensor,
79
+ v_scale: Tensor,
80
+ ) -> Tensor:
81
+ return torch.empty_like(q)
82
+
83
+
84
+ @paged_flash_attn_decode_kv_fp8.register_fake
85
+ def paged_flash_attn_decode_kv_fp8_fake(
86
+ q: Tensor,
87
+ k: Tensor,
88
+ v: Tensor,
89
+ mask: Tensor,
90
+ kcache: Tensor,
91
+ vcache: Tensor,
92
+ seq: Tensor,
93
+ scale: Tensor,
94
+ block_table: Tensor,
95
+ block_size: int,
96
+ partition: int,
97
+ k_scale: Tensor,
98
+ v_scale: Tensor,
99
+ ) -> Tensor:
100
+ return torch.empty_like(q)
101
+
102
+
62
103
  @torch.library.custom_op(
63
104
  "rbln_custom_ops::paged_flash_attn_prefill",
64
105
  mutates_args=(["kcache", "vcache"]),
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
100
141
  return torch.empty_like(q)
101
142
 
102
143
 
144
+ @torch.library.custom_op(
145
+ "rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
146
+ mutates_args=(["kcache", "vcache"]),
147
+ )
148
+ def paged_flash_attn_prefill_kv_fp8(
149
+ q: Tensor,
150
+ k: Tensor,
151
+ v: Tensor,
152
+ mask: Tensor,
153
+ kcache: Tensor,
154
+ vcache: Tensor,
155
+ seq: Tensor,
156
+ scale: Tensor,
157
+ block_table: Tensor,
158
+ block_size: int,
159
+ partition: int,
160
+ k_scale: Tensor,
161
+ v_scale: Tensor,
162
+ ) -> Tensor:
163
+ return torch.empty_like(q)
164
+
165
+
166
+ @paged_flash_attn_prefill_kv_fp8.register_fake
167
+ def paged_flash_attn_prefill_kv_fp8_fake(
168
+ q: Tensor,
169
+ k: Tensor,
170
+ v: Tensor,
171
+ mask: Tensor,
172
+ kcache: Tensor,
173
+ vcache: Tensor,
174
+ seq: Tensor,
175
+ scale: Tensor,
176
+ block_table: Tensor,
177
+ block_size: int,
178
+ partition: int,
179
+ k_scale: Tensor,
180
+ v_scale: Tensor,
181
+ ) -> Tensor:
182
+ return torch.empty_like(q)
183
+
184
+
103
185
  @torch.library.custom_op(
104
186
  "rbln_custom_ops::paged_flash_causal_attn_decode",
105
187
  mutates_args=(["kcache", "vcache"]),
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
141
223
  return torch.empty_like(q)
142
224
 
143
225
 
226
+ @torch.library.custom_op(
227
+ "rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
228
+ mutates_args=(["kcache", "vcache"]),
229
+ )
230
+ def paged_flash_causal_attn_decode_kv_fp8(
231
+ q: Tensor,
232
+ k: Tensor,
233
+ v: Tensor,
234
+ kcache: Tensor,
235
+ vcache: Tensor,
236
+ seq: Tensor,
237
+ scale: Tensor,
238
+ block_table: Tensor,
239
+ block_size: int,
240
+ partition: int,
241
+ k_scale: Tensor,
242
+ v_scale: Tensor,
243
+ mask: Optional[Tensor] = None,
244
+ ) -> Tensor:
245
+ return torch.empty_like(q)
246
+
247
+
248
+ @paged_flash_causal_attn_decode_kv_fp8.register_fake
249
+ def paged_flash_causal_attn_decode_kv_fp8_fake(
250
+ q: Tensor,
251
+ k: Tensor,
252
+ v: Tensor,
253
+ kcache: Tensor,
254
+ vcache: Tensor,
255
+ seq: Tensor,
256
+ scale: Tensor,
257
+ block_table: Tensor,
258
+ block_size: int,
259
+ partition: int,
260
+ k_scale: Tensor,
261
+ v_scale: Tensor,
262
+ mask: Optional[Tensor] = None,
263
+ ) -> Tensor:
264
+ return torch.empty_like(q)
265
+
266
+
144
267
  @torch.library.custom_op(
145
268
  "rbln_custom_ops::paged_flash_causal_attn_prefill",
146
269
  mutates_args=(["kcache", "vcache"]),
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
182
305
  mask: Optional[Tensor] = None,
183
306
  ) -> Tensor:
184
307
  return torch.empty_like(q)
308
+
309
+
310
+ @torch.library.custom_op(
311
+ "rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
312
+ mutates_args=(["kcache", "vcache"]),
313
+ )
314
+ def paged_flash_causal_attn_prefill_kv_fp8(
315
+ q: Tensor,
316
+ k: Tensor,
317
+ v: Tensor,
318
+ kcache: Tensor,
319
+ vcache: Tensor,
320
+ seq: Tensor,
321
+ scale: Tensor,
322
+ block_table: Tensor,
323
+ block_size: int,
324
+ partition: int,
325
+ is_bidirectional: bool,
326
+ k_scale: Tensor,
327
+ v_scale: Tensor,
328
+ mask: Optional[Tensor] = None,
329
+ ) -> Tensor:
330
+ return torch.empty_like(q)
331
+
332
+
333
+ @paged_flash_causal_attn_prefill_kv_fp8.register_fake
334
+ def paged_flash_causal_attn_prefill_kv_fp8_fake(
335
+ q: Tensor,
336
+ k: Tensor,
337
+ v: Tensor,
338
+ kcache: Tensor,
339
+ vcache: Tensor,
340
+ seq: Tensor,
341
+ scale: Tensor,
342
+ block_table: Tensor,
343
+ block_size: int,
344
+ partition: int,
345
+ is_bidirectional: bool,
346
+ k_scale: Tensor,
347
+ v_scale: Tensor,
348
+ mask: Optional[Tensor] = None,
349
+ ) -> Tensor:
350
+ return torch.empty_like(q)