optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -33,9 +33,9 @@ if is_cosmos_guardrail_available():
33
33
  from cosmos_guardrail import CosmosSafetyChecker
34
34
  from cosmos_guardrail.cosmos_guardrail import (
35
35
  COSMOS_GUARDRAIL_CHECKPOINT,
36
- Aegis,
37
36
  Blocklist,
38
37
  GuardrailRunner,
38
+ LlamaGuard3,
39
39
  ModelConfig,
40
40
  RetinaFaceFilter,
41
41
  SafetyClassifier,
@@ -55,7 +55,7 @@ else:
55
55
 
56
56
  COSMOS_GUARDRAIL_CHECKPOINT = None
57
57
 
58
- class Aegis(FailToImportCosmosGuardrail): ...
58
+ class LlamaGuard3(FailToImportCosmosGuardrail): ...
59
59
 
60
60
  class Blocklist(FailToImportCosmosGuardrail): ...
61
61
 
@@ -312,33 +312,31 @@ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
312
312
  self.encoder.save_pretrained(checkpoint_id)
313
313
 
314
314
 
315
- class RBLNAegis(Aegis):
315
+ class RBLNLlamaGuard3(LlamaGuard3):
316
316
  def __init__(
317
317
  self,
318
318
  checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
319
- base_model_id: str = "meta-llama/LlamaGuard-7b",
320
- aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
319
+ base_model_id: str = "meta-llama/Llama-Guard-3-8B",
321
320
  rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
322
321
  ) -> None:
323
322
  if is_compiled_dir(checkpoint_id):
324
323
  torch.nn.Module.__init__(self)
325
- cache_dir = pathlib.Path(checkpoint_id) / "aegis"
324
+ cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
326
325
  self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
327
- self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.aegis)
326
+ self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.llamaguard3)
328
327
 
329
328
  else:
330
- super().__init__(checkpoint_id, base_model_id, aegis_adapter)
331
- model = self.model.merge_and_unload() # peft merge
329
+ super().__init__(checkpoint_id, base_model_id)
330
+ model = self.model
332
331
  del self.model
333
-
334
- self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)
332
+ self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.llamaguard3)
335
333
 
336
334
  self.rbln_config = rbln_config
337
335
  self.dtype = torch.bfloat16
338
336
  self.device = torch.device("cpu")
339
337
 
340
338
  def save_pretrained(self, checkpoint_id: str):
341
- cache_dir = pathlib.Path(checkpoint_id) / "aegis"
339
+ cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
342
340
  self.model.save_pretrained(cache_dir)
343
341
  self.tokenizer.save_pretrained(cache_dir)
344
342
 
@@ -351,8 +349,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
351
349
  def __init__(
352
350
  self,
353
351
  checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
354
- aegis_model_id: str = "meta-llama/LlamaGuard-7b",
355
- aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
352
+ llamaguard_model_id: str = "meta-llama/Llama-Guard-3-8B",
356
353
  rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
357
354
  ) -> None:
358
355
  torch.nn.Module.__init__(self)
@@ -369,10 +366,9 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
369
366
  self.text_guardrail = GuardrailRunner(
370
367
  safety_models=[
371
368
  Blocklist(COSMOS_GUARDRAIL_CHECKPOINT), # Changed since it cannot be saved
372
- RBLNAegis(
369
+ RBLNLlamaGuard3(
373
370
  checkpoint_id=checkpoint_id,
374
- base_model_id=aegis_model_id,
375
- aegis_adapter=aegis_adapter_id,
371
+ base_model_id=llamaguard_model_id,
376
372
  rbln_config=rbln_config,
377
373
  ),
378
374
  ]
@@ -387,7 +383,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
387
383
 
388
384
  def save_pretrained(self, save_dir: str):
389
385
  for text_safety_models in self.text_guardrail.safety_models:
390
- if isinstance(text_safety_models, RBLNAegis):
386
+ if isinstance(text_safety_models, RBLNLlamaGuard3):
391
387
  text_safety_models.save_pretrained(save_dir)
392
388
 
393
389
  for video_safety_models in self.video_guardrail.safety_models:
@@ -87,8 +87,38 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
87
87
  export: bool = False,
88
88
  safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
89
89
  rbln_config: Dict[str, Any] = {},
90
- **kwargs: Dict[str, Any],
90
+ **kwargs: Any,
91
91
  ):
92
+ """
93
+ Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
94
+
95
+ This method has two distinct operating modes:
96
+ - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
97
+ - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
98
+
99
+ It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
100
+
101
+ Args:
102
+ model_id (`str`):
103
+ The model ID or path to the pretrained model to load. Can be either:
104
+
105
+ - A model ID from the HuggingFace Hub
106
+ - A local path to a saved model directory
107
+ export:
108
+ If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
109
+ If False, loads an already compiled RBLN model from `model_id` without recompilation.
110
+ safety_checker:
111
+ Optional custom safety checker to use instead of the default one. Only used when `export=True`.
112
+ rbln_config:
113
+ Configuration options for RBLN compilation. Can include settings for specific submodules
114
+ such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
115
+ pipeline being compiled.
116
+ kwargs:
117
+ Additional arguments to pass to the underlying diffusion pipeline constructor or the
118
+ RBLN compilation process. These may include parameters specific to individual submodules
119
+ or the particular diffusion pipeline being used.
120
+ """
121
+
92
122
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
93
123
  if safety_checker is None and export:
94
124
  safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
@@ -87,8 +87,38 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
87
87
  export: bool = False,
88
88
  safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
89
89
  rbln_config: Dict[str, Any] = {},
90
- **kwargs: Dict[str, Any],
90
+ **kwargs: Any,
91
91
  ):
92
+ """
93
+ Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
94
+
95
+ This method has two distinct operating modes:
96
+ - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
97
+ - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
98
+
99
+ It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
100
+
101
+ Args:
102
+ model_id (`str`):
103
+ The model ID or path to the pretrained model to load. Can be either:
104
+
105
+ - A model ID from the HuggingFace Hub
106
+ - A local path to a saved model directory
107
+ export:
108
+ If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
109
+ If False, loads an already compiled RBLN model from `model_id` without recompilation.
110
+ safety_checker:
111
+ Optional custom safety checker to use instead of the default one. Only used when `export=True`.
112
+ rbln_config:
113
+ Configuration options for RBLN compilation. Can include settings for specific submodules
114
+ such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
115
+ pipeline being compiled.
116
+ kwargs:
117
+ Additional arguments to pass to the underlying diffusion pipeline constructor or the
118
+ RBLN compilation process. These may include parameters specific to individual submodules
119
+ or the particular diffusion pipeline being used.
120
+ """
121
+
92
122
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
93
123
  if safety_checker is None and export:
94
124
  safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
@@ -22,12 +22,7 @@ from diffusers import (
22
22
  UNet2DConditionModel,
23
23
  VQModel,
24
24
  )
25
- from transformers import (
26
- CLIPImageProcessor,
27
- CLIPTextModelWithProjection,
28
- CLIPTokenizer,
29
- CLIPVisionModelWithProjection,
30
- )
25
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
31
26
 
32
27
  from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
33
28
  from ...modeling_diffusers import RBLNDiffusionMixin
optimum/rbln/modeling.py CHANGED
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, ge
19
19
  import rebel
20
20
  import torch
21
21
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
22
- from transformers import AutoConfig, PretrainedConfig
22
+ from transformers import PretrainedConfig
23
23
  from transformers.modeling_outputs import BaseModelOutput
24
24
 
25
25
  from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
@@ -34,6 +34,49 @@ if TYPE_CHECKING:
34
34
  logger = get_logger(__name__)
35
35
 
36
36
 
37
+ def _get_dtype(
38
+ cls,
39
+ dtype: Optional[Union[str, torch.dtype, dict]],
40
+ config: PretrainedConfig,
41
+ ) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
42
+ dtype_orig = None
43
+
44
+ if dtype is not None:
45
+ if isinstance(dtype, str):
46
+ if dtype == "auto":
47
+ if hasattr(config, "dtype") and config.dtype is not None:
48
+ dtype = config.dtype
49
+ else:
50
+ dtype = torch.get_default_dtype()
51
+ elif hasattr(torch, dtype):
52
+ dtype = getattr(torch, dtype)
53
+ config.dtype = dtype
54
+ elif isinstance(dtype, torch.dtype):
55
+ config.dtype = dtype
56
+ elif isinstance(dtype, dict):
57
+ for key, curr_dtype in dtype.items():
58
+ if hasattr(config, key):
59
+ value = getattr(config, key)
60
+ curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
61
+ value.dtype = curr_dtype
62
+ # main torch dtype for modules that aren't part of any sub-config
63
+ dtype = dtype.get("")
64
+ dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
65
+ config.dtype = dtype
66
+ if dtype is None:
67
+ dtype = torch.float32
68
+ else:
69
+ raise ValueError(f"Invalid dtype: {dtype}")
70
+
71
+ dtype_orig = cls._set_default_dtype(dtype)
72
+ else:
73
+ # Use default dtype
74
+ default_dtype = torch.get_default_dtype()
75
+ config.dtype = default_dtype
76
+
77
+ return config, dtype, dtype_orig
78
+
79
+
37
80
  class RBLNModel(RBLNBaseModel):
38
81
  @classmethod
39
82
  def update_kwargs(cls, kwargs):
@@ -70,6 +113,10 @@ class RBLNModel(RBLNBaseModel):
70
113
  )
71
114
  return compiled_model
72
115
 
116
+ @classmethod
117
+ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
118
+ return model
119
+
73
120
  @classmethod
74
121
  def from_model(
75
122
  cls,
@@ -78,18 +125,20 @@ class RBLNModel(RBLNBaseModel):
78
125
  rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
79
126
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
80
127
  subfolder: str = "",
81
- **kwargs: Dict[str, Any],
128
+ **kwargs: Any,
82
129
  ) -> "RBLNModel":
83
130
  """
84
131
  Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
85
132
  This method performs the actual model conversion and compilation process.
86
133
 
87
134
  Args:
88
- model: The PyTorch model to be compiled. The object must be an instance of the HuggingFace transformers PreTrainedModel class.
89
- rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
135
+ model (PreTrainedModel): The PyTorch model to be compiled.
136
+ The object must be an instance of the HuggingFace transformers PreTrainedModel class.
137
+ config (Optional[PretrainedConfig]): The configuration object associated with the model.
138
+ rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
139
+ This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
90
140
  For detailed configuration options, see the specific model's configuration class documentation.
91
-
92
- kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
141
+ kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
93
142
 
94
143
  The method performs the following steps:
95
144
 
@@ -99,8 +148,10 @@ class RBLNModel(RBLNBaseModel):
99
148
  4. Saves the compiled model and configurations
100
149
 
101
150
  Returns:
102
- A RBLN model instance ready for inference on RBLN NPU devices.
151
+ (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
103
152
  """
153
+
154
+ model = cls._reconstruct_model_if_needed(model)
104
155
  preprocessors = kwargs.pop("preprocessors", [])
105
156
  rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
106
157
 
@@ -119,9 +170,6 @@ class RBLNModel(RBLNBaseModel):
119
170
  # Save configs
120
171
  if config is None:
121
172
  config = model.config
122
- # remote_config
123
- if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
124
- config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
125
173
 
126
174
  if hasattr(model, "can_generate") and model.can_generate():
127
175
  import json
@@ -147,6 +195,7 @@ class RBLNModel(RBLNBaseModel):
147
195
  model=model,
148
196
  model_save_dir=save_dir,
149
197
  rbln_config=rbln_config,
198
+ preprocessors=preprocessors,
150
199
  **kwargs,
151
200
  )
152
201
  else:
@@ -209,6 +258,7 @@ class RBLNModel(RBLNBaseModel):
209
258
  **kwargs,
210
259
  ) -> "PreTrainedModel":
211
260
  kwargs = cls.update_kwargs(kwargs)
261
+
212
262
  return cls.get_hf_class().from_pretrained(
213
263
  model_id,
214
264
  subfolder=subfolder,
@@ -241,31 +291,33 @@ class RBLNModel(RBLNBaseModel):
241
291
  for compiled_model in compiled_models
242
292
  ]
243
293
 
244
- def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
294
+ def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
245
295
  """
246
- Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
296
+ Defines the forward pass of `RBLNModel`. The interface mirrors HuggingFace conventions so it can act as a drop-in
297
+ replacement in many cases.
247
298
 
248
- This method executes the compiled RBLN model on RBLN NPU devices while maintaining full compatibility
249
- with HuggingFace transformers and diffusers APIs. The RBLNModel can be used as a direct substitute
250
- for any HuggingFace nn.Module/PreTrainedModel, enabling seamless integration into existing workflows.
299
+ This method executes the compiled RBLN model on RBLN NPU devices while remaining fully compatible with Hugging Face
300
+ Transformers and Diffusers APIs. In practice, `RBLNModel` can replace models built on `torch.nn.Module` — including
301
+ `transformers.PreTrainedModel` implementations and Diffusers components based on `diffusers.ModelMixin` enabling
302
+ seamless integration into existing workflows.
251
303
 
252
304
  Args:
253
- *args: Variable length argument list containing model inputs. The format matches the original
305
+ args: Variable length argument list containing model inputs. The format matches the original
254
306
  HuggingFace model's forward method signature (e.g., input_ids, attention_mask for
255
307
  transformers models, or sample, timestep for diffusers models).
256
308
  return_dict:
257
309
  Whether to return outputs as a dictionary-like object or as a tuple. When `None`:
258
310
  - For transformers models: Uses `self.config.use_return_dict` (typically `True`)
259
311
  - For diffusers models: Defaults to `True`
260
- **kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
312
+ kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
261
313
  matching the original HuggingFace model's interface.
262
314
 
263
315
  Returns:
264
316
  Model outputs in the same format as the original HuggingFace model.
265
317
 
266
- - If `return_dict=True`: Returns a dictionary-like object (e.g., BaseModelOutput,
318
+ If `return_dict=True`, Returns a dictionary-like object (e.g., BaseModelOutput,
267
319
  CausalLMOutput) with named fields such as `logits`, `hidden_states`, etc.
268
- - If `return_dict=False`: Returns a tuple containing the raw model outputs.
320
+ If `return_dict=False`, Returns a tuple containing the raw model outputs.
269
321
 
270
322
  Note:
271
323
  - This method maintains the exact same interface as the original HuggingFace model's forward method
@@ -34,7 +34,7 @@ from .utils.submodule import SubModulesMixin
34
34
 
35
35
 
36
36
  if TYPE_CHECKING:
37
- from transformers import PreTrainedModel
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
38
38
 
39
39
  logger = get_logger(__name__)
40
40
 
@@ -53,6 +53,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
53
53
  config_class = AutoConfig
54
54
  config_name = "config.json"
55
55
  hf_library_name = "transformers"
56
+ _supports_non_fp32 = False
56
57
 
57
58
  def __init__(
58
59
  self,
@@ -91,7 +92,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
91
92
 
92
93
  self.device = torch.device("cpu")
93
94
  self.training = False
94
- self.dtype = torch.float32
95
+ self.dtype = rbln_config.torch_dtype
95
96
 
96
97
  # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
97
98
  # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -314,7 +315,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
314
315
  rbln_config,
315
316
  model_save_dir=model_save_dir,
316
317
  subfolder=subfolder,
317
- rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
318
+ rbln_compiled_models=rbln_compiled_models,
318
319
  rbln_submodules=rbln_submodules,
319
320
  **kwargs,
320
321
  )
@@ -342,32 +343,72 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
342
343
  rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
343
344
  return rbln_config, kwargs
344
345
 
346
+ @classmethod
347
+ def _is_compiled(
348
+ cls,
349
+ model_id: Union[str, Path],
350
+ token: Optional[Union[bool, str]] = None,
351
+ revision: Optional[str] = None,
352
+ force_download: bool = False,
353
+ cache_dir: Optional[str] = None,
354
+ subfolder: str = "",
355
+ local_files_only: bool = False,
356
+ ) -> bool:
357
+ # Check if the model is already compiled.
358
+ try:
359
+ cls._load_compiled_model_dir(
360
+ model_id=model_id,
361
+ token=token,
362
+ revision=revision,
363
+ force_download=force_download,
364
+ cache_dir=cache_dir,
365
+ subfolder=subfolder,
366
+ local_files_only=local_files_only,
367
+ )
368
+ return True
369
+ except (FileNotFoundError, KeyError):
370
+ return False
371
+
345
372
  @classmethod
346
373
  def from_pretrained(
347
374
  cls: Type["RBLNBaseModel"],
348
375
  model_id: Union[str, Path],
349
- export: bool = False,
376
+ export: Optional[bool] = None,
350
377
  rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
351
- **kwargs: Dict[str, Any],
378
+ **kwargs: Any,
352
379
  ) -> "RBLNBaseModel":
353
380
  """
354
381
  The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
355
382
  Users can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
356
383
 
357
384
  Args:
358
- model_id: The model id of the pre-trained model to be loaded. It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
359
- export: A boolean flag to indicate whether the model should be compiled.
360
- rbln_config: Configuration for RBLN model compilation and runtime. This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
385
+ model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
386
+ It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
387
+ export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
388
+ If None, it will be determined based on the existence of the compiled model files in the model_id.
389
+ rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
390
+ This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
361
391
  For detailed configuration options, see the specific model's configuration class documentation.
362
-
363
- kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
392
+ kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
364
393
 
365
394
  Returns:
366
- A RBLN model instance ready for inference on RBLN NPU devices.
395
+ (RBLNBaseModel): A RBLN model instance ready for inference on RBLN NPU devices.
367
396
  """
368
397
 
369
398
  if isinstance(model_id, Path):
370
399
  model_id = model_id.as_posix()
400
+
401
+ if export is None:
402
+ export = not cls._is_compiled(
403
+ model_id=model_id,
404
+ token=kwargs.get("token"),
405
+ revision=kwargs.get("revision"),
406
+ force_download=kwargs.get("force_download", False),
407
+ cache_dir=kwargs.get("cache_dir"),
408
+ subfolder=kwargs.get("subfolder", ""),
409
+ local_files_only=kwargs.get("local_files_only", False),
410
+ )
411
+
371
412
  from_pretrained_method = cls._export if export else cls._from_pretrained
372
413
  return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)
373
414
 
@@ -392,7 +433,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
392
433
  compiled_model = rebel.compile_from_torch(
393
434
  model,
394
435
  input_info=rbln_compile_config.input_info,
395
- fusion=rbln_compile_config.fusion,
396
436
  npu=rbln_compile_config.npu,
397
437
  tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
398
438
  **kwargs,
@@ -400,8 +440,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
400
440
  return compiled_model
401
441
 
402
442
  @classmethod
403
- def update_rbln_config(cls, **others) -> RBLNModelConfig:
404
- rbln_config = cls._update_rbln_config(**others)
443
+ def update_rbln_config(
444
+ cls,
445
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
446
+ model: "PreTrainedModel",
447
+ model_config: "PretrainedConfig",
448
+ rbln_config: RBLNModelConfig,
449
+ ) -> RBLNModelConfig:
450
+ rbln_config.torch_dtype = model.dtype
451
+ if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
452
+ raise NotImplementedError(
453
+ f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
454
+ )
455
+ rbln_config = cls._update_rbln_config(
456
+ preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
457
+ )
405
458
  rbln_config.freeze()
406
459
  if rbln_config.rbln_model_cls_name != cls.__name__:
407
460
  raise NameError(
@@ -444,12 +497,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
444
497
 
445
498
  # This method mimics the interface of torch.nn.Module.parameters()
446
499
  # specifically for code that uses `next(model.parameters())` to infer
447
- # the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
500
+ # the device or dtype. It yields a single dummy tensor on CPU with the model's dtype.
448
501
 
449
502
  # Warning:
450
503
  # This does NOT yield the actual model parameters used by the RBLN runtime.
451
504
  # Code relying on iterating through all model parameters will not work as expected.
452
- yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
505
+ yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))
453
506
 
454
507
  def __call__(self, *args, **kwargs):
455
508
  return self.forward(*args, **kwargs)
@@ -484,9 +537,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
484
537
  [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.
485
538
 
486
539
  Args:
487
- save_directory (`Union[str, Path]`):
540
+ save_directory (Union[str, Path]):
488
541
  Directory where to save the model file.
489
- push_to_hub (`bool`, *optional*, defaults to `False`):
542
+ push_to_hub (bool):
490
543
  Whether or not to push your model to the HuggingFace model hub after saving it.
491
544
 
492
545
  """
@@ -523,10 +576,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
523
576
  # First copy everything to a temporary directory
524
577
  shutil.copytree(real_save_dir, tmp_dir)
525
578
 
526
- # If everything succeeded, atomically replace the target directory
579
+ # If everything succeeded, move files to target directory
527
580
  if os.path.exists(save_directory_path):
528
- shutil.rmtree(save_directory_path)
529
- os.rename(tmp_dir, save_directory_path)
581
+ # Merge files from tmp_dir into existing directory
582
+ def _merge_dir(src_root: str, dst_root: str):
583
+ for name in os.listdir(src_root):
584
+ src_item = os.path.join(src_root, name)
585
+ dst_item = os.path.join(dst_root, name)
586
+
587
+ if os.path.islink(src_item) or os.path.isfile(src_item):
588
+ os.makedirs(os.path.dirname(dst_item), exist_ok=True)
589
+ if os.path.isdir(dst_item) and not os.path.islink(dst_item):
590
+ shutil.rmtree(dst_item)
591
+ os.replace(src_item, dst_item)
592
+ elif os.path.isdir(src_item):
593
+ if os.path.islink(dst_item) or os.path.isfile(dst_item):
594
+ os.remove(dst_item)
595
+ os.makedirs(dst_item, exist_ok=True)
596
+ _merge_dir(src_item, dst_item)
597
+ else:
598
+ # Fallback for special file types
599
+ os.replace(src_item, dst_item)
600
+
601
+ _merge_dir(tmp_dir, str(save_directory_path))
602
+
603
+ # Remove the temporary directory tree after merge
604
+ shutil.rmtree(tmp_dir)
605
+ else:
606
+ # If target doesn't exist, just rename tmp_dir to target
607
+ os.rename(tmp_dir, save_directory_path)
530
608
 
531
609
  except Exception as e:
532
610
  # Clean up the temporary directory if anything fails