optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -33,9 +33,9 @@ if is_cosmos_guardrail_available():
     from cosmos_guardrail import CosmosSafetyChecker
     from cosmos_guardrail.cosmos_guardrail import (
         COSMOS_GUARDRAIL_CHECKPOINT,
-        Aegis,
         Blocklist,
         GuardrailRunner,
+        LlamaGuard3,
         ModelConfig,
         RetinaFaceFilter,
         SafetyClassifier,
@@ -55,7 +55,7 @@ else:

     COSMOS_GUARDRAIL_CHECKPOINT = None

-    class
+    class LlamaGuard3(FailToImportCosmosGuardrail): ...

     class Blocklist(FailToImportCosmosGuardrail): ...

@@ -312,33 +312,31 @@ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
         self.encoder.save_pretrained(checkpoint_id)


-class
+class RBLNLlamaGuard3(LlamaGuard3):
     def __init__(
         self,
         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
-        base_model_id: str = "meta-llama/
-        aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+        base_model_id: str = "meta-llama/Llama-Guard-3-8B",
         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
     ) -> None:
         if is_compiled_dir(checkpoint_id):
             torch.nn.Module.__init__(self)
-            cache_dir = pathlib.Path(checkpoint_id) / "
+            cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
             self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
-            self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.
+            self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.llamaguard3)

         else:
-            super().__init__(checkpoint_id, base_model_id
-            model = self.model
+            super().__init__(checkpoint_id, base_model_id)
+            model = self.model
             del self.model
-
-            self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)
+            self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.llamaguard3)

         self.rbln_config = rbln_config
         self.dtype = torch.bfloat16
         self.device = torch.device("cpu")

     def save_pretrained(self, checkpoint_id: str):
-        cache_dir = pathlib.Path(checkpoint_id) / "
+        cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
         self.model.save_pretrained(cache_dir)
         self.tokenizer.save_pretrained(cache_dir)

@@ -351,8 +349,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
     def __init__(
         self,
         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
-
-        aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+        llamaguard_model_id: str = "meta-llama/Llama-Guard-3-8B",
         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
     ) -> None:
         torch.nn.Module.__init__(self)
@@ -369,10 +366,9 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
         self.text_guardrail = GuardrailRunner(
             safety_models=[
                 Blocklist(COSMOS_GUARDRAIL_CHECKPOINT),  # Changed since it cannot be saved
-
+                RBLNLlamaGuard3(
                     checkpoint_id=checkpoint_id,
-                    base_model_id=
-                    aegis_adapter=aegis_adapter_id,
+                    base_model_id=llamaguard_model_id,
                     rbln_config=rbln_config,
                 ),
             ]
@@ -387,7 +383,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):

     def save_pretrained(self, save_dir: str):
         for text_safety_models in self.text_guardrail.safety_models:
-            if isinstance(text_safety_models,
+            if isinstance(text_safety_models, RBLNLlamaGuard3):
                 text_safety_models.save_pretrained(save_dir)

         for video_safety_models in self.video_guardrail.safety_models:
@@ -87,8 +87,38 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
         rbln_config: Dict[str, Any] = {},
-        **kwargs:
+        **kwargs: Any,
     ):
+        """
+        Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
+
+        This method has two distinct operating modes:
+        - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+        - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+
+        It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
+
+        Args:
+            model_id (`str`):
+                The model ID or path to the pretrained model to load. Can be either:
+
+                - A model ID from the HuggingFace Hub
+                - A local path to a saved model directory
+            export:
+                If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
+                If False, loads an already compiled RBLN model from `model_id` without recompilation.
+            safety_checker:
+                Optional custom safety checker to use instead of the default one. Only used when `export=True`.
+            rbln_config:
+                Configuration options for RBLN compilation. Can include settings for specific submodules
+                such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
+                pipeline being compiled.
+            kwargs:
+                Additional arguments to pass to the underlying diffusion pipeline constructor or the
+                RBLN compilation process. These may include parameters specific to individual submodules
+                or the particular diffusion pipeline being used.
+        """
+
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
@@ -87,8 +87,38 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
         rbln_config: Dict[str, Any] = {},
-        **kwargs:
+        **kwargs: Any,
     ):
+        """
+        Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
+
+        This method has two distinct operating modes:
+        - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+        - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+
+        It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
+
+        Args:
+            model_id (`str`):
+                The model ID or path to the pretrained model to load. Can be either:
+
+                - A model ID from the HuggingFace Hub
+                - A local path to a saved model directory
+            export:
+                If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
+                If False, loads an already compiled RBLN model from `model_id` without recompilation.
+            safety_checker:
+                Optional custom safety checker to use instead of the default one. Only used when `export=True`.
+            rbln_config:
+                Configuration options for RBLN compilation. Can include settings for specific submodules
+                such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
+                pipeline being compiled.
+            kwargs:
+                Additional arguments to pass to the underlying diffusion pipeline constructor or the
+                RBLN compilation process. These may include parameters specific to individual submodules
+                or the particular diffusion pipeline being used.
+        """
+
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
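The docstring added to both Cosmos pipelines above describes the two `from_pretrained` modes. A minimal usage sketch of that contract follows; the checkpoint ID and output directory are placeholders, and it assumes the pipeline exposes the usual `save_pretrained` like other RBLN models:

```python
from optimum.rbln import RBLNCosmosTextToWorldPipeline

# export=True: compile the PyTorch pipeline for RBLN NPUs, then persist the compiled artifacts.
pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # hypothetical checkpoint ID
    export=True,
)
pipe.save_pretrained("cosmos-text2world-rbln")

# export=False: reload the already-compiled directory without recompiling.
pipe = RBLNCosmosTextToWorldPipeline.from_pretrained("cosmos-text2world-rbln", export=False)
```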
@@ -22,12 +22,7 @@ from diffusers import (
     UNet2DConditionModel,
     VQModel,
 )
-from transformers import
-    CLIPImageProcessor,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
 from ...modeling_diffusers import RBLNDiffusionMixin
@@ -0,0 +1,15 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pipeline_stable_video_diffusion import RBLNStableVideoDiffusionPipeline
@@ -0,0 +1,46 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from diffusers import StableVideoDiffusionPipeline
+
+from ....utils.logging import get_logger
+from ...configurations import RBLNStableVideoDiffusionPipelineConfig
+from ...modeling_diffusers import RBLNDiffusionMixin
+
+
+logger = get_logger(__name__)
+
+
+class RBLNStableVideoDiffusionPipeline(RBLNDiffusionMixin, StableVideoDiffusionPipeline):
+    """
+    RBLN-accelerated implementation of Stable Video Diffusion pipeline for image-to-video generation.
+
+    This pipeline compiles Stable Video Diffusion models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for generating videos from images with optimized memory usage and throughput.
+    """
+
+    original_class = StableVideoDiffusionPipeline
+    _rbln_config_class = RBLNStableVideoDiffusionPipelineConfig
+    _submodules = ["image_encoder", "unet", "vae"]
+
+    def handle_additional_kwargs(self, **kwargs):
+        compiled_num_frames = self.unet.rbln_config.num_frames
+        if compiled_num_frames is not None:
+            kwargs["num_frames"] = compiled_num_frames
+
+        compiled_decode_chunk_size = self.vae.rbln_config.decode_chunk_size
+        if compiled_decode_chunk_size is not None:
+            kwargs["decode_chunk_size"] = compiled_decode_chunk_size
+        return kwargs
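`handle_additional_kwargs` above pins `num_frames` and `decode_chunk_size` to the values the pipeline was compiled with, overriding whatever is passed at call time. A short sketch of how this new pipeline might be used; the checkpoint ID is a placeholder and the nested `rbln_config` keys are assumed from the attributes read in `handle_additional_kwargs`:

```python
from diffusers.utils import load_image
from optimum.rbln import RBLNStableVideoDiffusionPipeline

pipe = RBLNStableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",  # hypothetical checkpoint ID
    export=True,
    rbln_config={"unet": {"num_frames": 14}, "vae": {"decode_chunk_size": 2}},  # assumed config keys
)

image = load_image("input.png")
# The compiled num_frames / decode_chunk_size take precedence over call-time kwargs.
frames = pipe(image).frames[0]
```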
optimum/rbln/modeling.py CHANGED
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, ge
 import rebel
 import torch
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from transformers import
+from transformers import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput

 from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
@@ -54,13 +54,16 @@ class RBLNModel(RBLNBaseModel):
         pass

     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model

     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-
+        if rbln_config._allow_no_compile_cfgs:
+            return {}
+
+        model = cls._wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(
             model,
@@ -70,6 +73,22 @@ class RBLNModel(RBLNBaseModel):
         )
         return compiled_model

+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Any],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        # Default implementation: return config as-is
+        # Subclasses should override to set compile_cfgs if needed
+        return rbln_config
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        return model
+
     @classmethod
     def from_model(
         cls,
@@ -78,18 +97,20 @@ class RBLNModel(RBLNBaseModel):
         rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
-        **kwargs:
+        **kwargs: Any,
     ) -> "RBLNModel":
         """
         Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
         This method performs the actual model conversion and compilation process.

         Args:
-            model: The PyTorch model to be compiled.
-
+            model (PreTrainedModel): The PyTorch model to be compiled.
+                The object must be an instance of the HuggingFace transformers PreTrainedModel class.
+            config (Optional[PretrainedConfig]): The configuration object associated with the model.
+            rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
+                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
                 For detailed configuration options, see the specific model's configuration class documentation.
-
-            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.

         The method performs the following steps:

@@ -99,8 +120,10 @@ class RBLNModel(RBLNBaseModel):
         4. Saves the compiled model and configurations

         Returns:
-            A RBLN model instance ready for inference on RBLN NPU devices.
+            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
         """
+
+        model = cls._reconstruct_model_if_needed(model)
         preprocessors = kwargs.pop("preprocessors", [])
         rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)

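As the updated `from_model` docstring notes, this path compiles an already-instantiated PyTorch model rather than loading from a checkpoint. A minimal sketch, with the checkpoint ID as a placeholder and the `tensor_parallel_size` key assumed from the compile config fields seen elsewhere in this diff:

```python
from transformers import AutoModelForCausalLM
from optimum.rbln import RBLNLlamaForCausalLM

# Load (or modify) the PyTorch model yourself, then hand it to from_model for compilation.
torch_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical ID
rbln_model = RBLNLlamaForCausalLM.from_model(
    torch_model,
    rbln_config={"tensor_parallel_size": 4},  # assumed config key
)
rbln_model.save_pretrained("llama-2-7b-rbln")
```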
@@ -119,9 +142,6 @@ class RBLNModel(RBLNBaseModel):
         # Save configs
         if config is None:
             config = model.config
-            # remote_config
-            if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
-                config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)

         if hasattr(model, "can_generate") and model.can_generate():
             import json
@@ -147,6 +167,7 @@ class RBLNModel(RBLNBaseModel):
                 model=model,
                 model_save_dir=save_dir,
                 rbln_config=rbln_config,
+                preprocessors=preprocessors,
                 **kwargs,
             )
         else:
@@ -209,6 +230,7 @@ class RBLNModel(RBLNBaseModel):
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
+
         return cls.get_hf_class().from_pretrained(
             model_id,
             subfolder=subfolder,
@@ -227,6 +249,9 @@ class RBLNModel(RBLNBaseModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
+        if len(rbln_config.compile_cfgs) == 0:
+            return []
+
         if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])

@@ -241,31 +266,33 @@ class RBLNModel(RBLNBaseModel):
             for compiled_model in compiled_models
         ]

-    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs:
+    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
         """
-        Defines the forward pass of
+        Defines the forward pass of `RBLNModel`. The interface mirrors HuggingFace conventions so it can act as a drop-in
+        replacement in many cases.

-        This method executes the compiled RBLN model on RBLN NPU devices while
-
-
+        This method executes the compiled RBLN model on RBLN NPU devices while remaining fully compatible with Hugging Face
+        Transformers and Diffusers APIs. In practice, `RBLNModel` can replace models built on `torch.nn.Module` — including
+        `transformers.PreTrainedModel` implementations and Diffusers components based on `diffusers.ModelMixin` — enabling
+        seamless integration into existing workflows.

         Args:
-
+            args: Variable length argument list containing model inputs. The format matches the original
                 HuggingFace model's forward method signature (e.g., input_ids, attention_mask for
                 transformers models, or sample, timestep for diffusers models).
             return_dict:
                 Whether to return outputs as a dictionary-like object or as a tuple. When `None`:
                 - For transformers models: Uses `self.config.use_return_dict` (typically `True`)
                 - For diffusers models: Defaults to `True`
-
+            kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
                 matching the original HuggingFace model's interface.

         Returns:
             Model outputs in the same format as the original HuggingFace model.

-
+            If `return_dict=True`, Returns a dictionary-like object (e.g., BaseModelOutput,
                 CausalLMOutput) with named fields such as `logits`, `hidden_states`, etc.
-
+            If `return_dict=False`, Returns a tuple containing the raw model outputs.

         Note:
             - This method maintains the exact same interface as the original HuggingFace model's forward method
optimum/rbln/modeling_base.py CHANGED
@@ -34,7 +34,7 @@ from .utils.submodule import SubModulesMixin


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel

 logger = get_logger(__name__)

@@ -53,6 +53,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     config_class = AutoConfig
     config_name = "config.json"
     hf_library_name = "transformers"
+    _supports_non_fp32 = False

     def __init__(
         self,
@@ -70,7 +71,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         self.rbln_config = rbln_config
         if not rbln_config.is_frozen():
             raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
-
         self.compiled_models = rbln_compiled_models

         # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
@@ -91,7 +91,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

         self.device = torch.device("cpu")
         self.training = False
-        self.dtype =
+        self.dtype = rbln_config.torch_dtype

         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -314,7 +314,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             rbln_config,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=
+            rbln_compiled_models=rbln_compiled_models,
             rbln_submodules=rbln_submodules,
             **kwargs,
         )
@@ -342,32 +342,72 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
         return rbln_config, kwargs

+    @classmethod
+    def _is_compiled(
+        cls,
+        model_id: Union[str, Path],
+        token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+    ) -> bool:
+        # Check if the model is already compiled.
+        try:
+            cls._load_compiled_model_dir(
+                model_id=model_id,
+                token=token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                subfolder=subfolder,
+                local_files_only=local_files_only,
+            )
+            return True
+        except (FileNotFoundError, KeyError):
+            return False
+
     @classmethod
     def from_pretrained(
         cls: Type["RBLNBaseModel"],
         model_id: Union[str, Path],
-        export: bool =
+        export: Optional[bool] = None,
         rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
-        **kwargs:
+        **kwargs: Any,
     ) -> "RBLNBaseModel":
         """
         The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
         User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.

         Args:
-            model_id: The model id of the pre-trained model to be loaded.
-
-
+            model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
+                It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
+                If None, it will be determined based on the existence of the compiled model files in the model_id.
+            rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
+                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
                 For detailed configuration options, see the specific model's configuration class documentation.
-
-            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.

         Returns:
-            A RBLN model instance ready for inference on RBLN NPU devices.
+            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
         """

         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
+
+        if export is None:
+            export = not cls._is_compiled(
+                model_id=model_id,
+                token=kwargs.get("token"),
+                revision=kwargs.get("revision"),
+                force_download=kwargs.get("force_download", False),
+                cache_dir=kwargs.get("cache_dir"),
+                subfolder=kwargs.get("subfolder", ""),
+                local_files_only=kwargs.get("local_files_only", False),
+            )
+
         from_pretrained_method = cls._export if export else cls._from_pretrained
         return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)

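With `export` now defaulting to `None`, `from_pretrained` decides for itself whether `model_id` already points at compiled artifacts. A small sketch of the resulting behavior, assuming `RBLNBertModel` as the model class and a placeholder checkpoint ID:

```python
from optimum.rbln import RBLNBertModel

# Plain HuggingFace checkpoint: no compiled files are found, so export is inferred as True.
model = RBLNBertModel.from_pretrained("bert-base-uncased")  # hypothetical ID
model.save_pretrained("bert-base-rbln")

# Directory produced by save_pretrained: compiled files are found, so export is inferred as False.
model = RBLNBertModel.from_pretrained("bert-base-rbln")
```

Passing `export=True` or `export=False` explicitly still bypasses the detection.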
@@ -392,7 +432,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         compiled_model = rebel.compile_from_torch(
             model,
             input_info=rbln_compile_config.input_info,
-            fusion=rbln_compile_config.fusion,
             npu=rbln_compile_config.npu,
             tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
             **kwargs,
@@ -400,8 +439,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return compiled_model

     @classmethod
-    def update_rbln_config(
-
+    def update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNModelConfig,
+    ) -> RBLNModelConfig:
+        rbln_config.torch_dtype = model.dtype
+        if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+            raise NotImplementedError(
+                f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
+            )
+        rbln_config = cls._update_rbln_config(
+            preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
+        )
         rbln_config.freeze()
         if rbln_config.rbln_model_cls_name != cls.__name__:
             raise NameError(
@@ -444,12 +496,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

         # This method mimics the interface of torch.nn.Module.parameters()
         # specifically for code that uses `next(model.parameters())` to infer
-        # the device or dtype. It yields a single dummy tensor on CPU with
+        # the device or dtype. It yields a single dummy tensor on CPU with model dtype.

         # Warning:
         # This does NOT yield the actual model parameters used by the RBLN runtime.
         # Code relying on iterating through all model parameters will not work as expected.
-        yield torch.tensor([1.0], dtype=
+        yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))

     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
@@ -484,9 +536,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.

         Args:
-            save_directory (
+            save_directory (Union[str, Path]):
                 Directory where to save the model file.
-            push_to_hub (
+            push_to_hub (bool):
                 Whether or not to push your model to the HuggingFace model hub after saving it.

         """
@@ -523,10 +575,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             # First copy everything to a temporary directory
             shutil.copytree(real_save_dir, tmp_dir)

-            # If everything succeeded,
+            # If everything succeeded, move files to target directory
             if os.path.exists(save_directory_path):
-
-
+                # Merge files from tmp_dir into existing directory
+                def _merge_dir(src_root: str, dst_root: str):
+                    for name in os.listdir(src_root):
+                        src_item = os.path.join(src_root, name)
+                        dst_item = os.path.join(dst_root, name)
+
+                        if os.path.islink(src_item) or os.path.isfile(src_item):
+                            os.makedirs(os.path.dirname(dst_item), exist_ok=True)
+                            if os.path.isdir(dst_item) and not os.path.islink(dst_item):
+                                shutil.rmtree(dst_item)
+                            os.replace(src_item, dst_item)
+                        elif os.path.isdir(src_item):
+                            if os.path.islink(dst_item) or os.path.isfile(dst_item):
+                                os.remove(dst_item)
+                            os.makedirs(dst_item, exist_ok=True)
+                            _merge_dir(src_item, dst_item)
+                        else:
+                            # Fallback for special file types
+                            os.replace(src_item, dst_item)
+
+                _merge_dir(tmp_dir, str(save_directory_path))
+
+                # Remove the temporary directory tree after merge
+                shutil.rmtree(tmp_dir)
+            else:
+                # If target doesn't exist, just rename tmp_dir to target
+                os.rename(tmp_dir, save_directory_path)

         except Exception as e:
             # Clean up the temporary directory if anything fails