optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py
CHANGED

@@ -33,9 +33,9 @@ if is_cosmos_guardrail_available():
     from cosmos_guardrail import CosmosSafetyChecker
     from cosmos_guardrail.cosmos_guardrail import (
         COSMOS_GUARDRAIL_CHECKPOINT,
-        Aegis,
         Blocklist,
         GuardrailRunner,
+        LlamaGuard3,
         ModelConfig,
         RetinaFaceFilter,
         SafetyClassifier,

@@ -55,7 +55,7 @@ else:

     COSMOS_GUARDRAIL_CHECKPOINT = None

-    class Aegis(FailToImportCosmosGuardrail): ...
+    class LlamaGuard3(FailToImportCosmosGuardrail): ...

     class Blocklist(FailToImportCosmosGuardrail): ...

@@ -312,33 +312,31 @@ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
         self.encoder.save_pretrained(checkpoint_id)


-class RBLNAegis(Aegis):
+class RBLNLlamaGuard3(LlamaGuard3):
     def __init__(
         self,
         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
-        base_model_id: str = "meta-llama/
-        aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+        base_model_id: str = "meta-llama/Llama-Guard-3-8B",
         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
     ) -> None:
         if is_compiled_dir(checkpoint_id):
             torch.nn.Module.__init__(self)
-            cache_dir = pathlib.Path(checkpoint_id) / "
+            cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
             self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
-            self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.aegis)
+            self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.llamaguard3)

         else:
-            super().__init__(checkpoint_id, base_model_id
-            model = self.model
+            super().__init__(checkpoint_id, base_model_id)
+            model = self.model
             del self.model
-
-            self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)
+            self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.llamaguard3)

         self.rbln_config = rbln_config
         self.dtype = torch.bfloat16
         self.device = torch.device("cpu")

     def save_pretrained(self, checkpoint_id: str):
-        cache_dir = pathlib.Path(checkpoint_id) / "
+        cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
         self.model.save_pretrained(cache_dir)
         self.tokenizer.save_pretrained(cache_dir)

@@ -351,8 +349,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
     def __init__(
         self,
         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
-
-        aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+        llamaguard_model_id: str = "meta-llama/Llama-Guard-3-8B",
         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
     ) -> None:
         torch.nn.Module.__init__(self)

@@ -369,10 +366,9 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
         self.text_guardrail = GuardrailRunner(
             safety_models=[
                 Blocklist(COSMOS_GUARDRAIL_CHECKPOINT),  # Changed since it cannot be saved
-                RBLNAegis(
+                RBLNLlamaGuard3(
                     checkpoint_id=checkpoint_id,
-                    base_model_id=
-                    aegis_adapter=aegis_adapter_id,
+                    base_model_id=llamaguard_model_id,
                     rbln_config=rbln_config,
                 ),
             ]

@@ -387,7 +383,7 @@ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):

     def save_pretrained(self, save_dir: str):
         for text_safety_models in self.text_guardrail.safety_models:
-            if isinstance(text_safety_models, RBLNAegis):
+            if isinstance(text_safety_models, RBLNLlamaGuard3):
                 text_safety_models.save_pretrained(save_dir)

         for video_safety_models in self.video_guardrail.safety_models:
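The hunks above replace the NVIDIA Aegis adapter with LlamaGuard 3 as the Cosmos text guardrail. A minimal migration sketch, assuming the constructor arguments shown in the diff and that `RBLNCosmosSafetyChecker` is importable from the module path listed in the file table:

```python
# Migration sketch only; names come from the diff above, not from verified documentation.
from optimum.rbln.diffusers.pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyChecker

# 0.8.x configured the text guardrail via aegis_adapter_id; 0.9.x uses a LlamaGuard 3 base model.
checker = RBLNCosmosSafetyChecker(
    llamaguard_model_id="meta-llama/Llama-Guard-3-8B",  # new default shown in the diff
)
checker.save_pretrained("./cosmos_guardrail_rbln")  # LlamaGuard artifacts land under .../llamaguard3
```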
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
CHANGED

@@ -87,8 +87,38 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
         rbln_config: Dict[str, Any] = {},
-        **kwargs:
+        **kwargs: Any,
     ):
+        """
+        Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
+
+        This method has two distinct operating modes:
+
+        - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+        - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+
+        It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
+
+        Args:
+            model_id (`str`):
+                The model ID or path to the pretrained model to load. Can be either:
+
+                - A model ID from the HuggingFace Hub
+                - A local path to a saved model directory
+            export:
+                If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
+                If False, loads an already compiled RBLN model from `model_id` without recompilation.
+            safety_checker:
+                Optional custom safety checker to use instead of the default one. Only used when `export=True`.
+            rbln_config:
+                Configuration options for RBLN compilation. Can include settings for specific submodules
+                such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
+                pipeline being compiled.
+            kwargs:
+                Additional arguments to pass to the underlying diffusion pipeline constructor or the
+                RBLN compilation process. These may include parameters specific to individual submodules
+                or the particular diffusion pipeline being used.
+        """
+
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
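The new docstring describes two operating modes. A usage sketch, assuming `RBLNCosmosTextToWorldPipeline` is exported from the package root; the checkpoint ID and the `transformer` submodule key are illustrative:

```python
from optimum.rbln import RBLNCosmosTextToWorldPipeline

# export=True: compile the PyTorch pipeline for RBLN NPUs, then persist the compiled artifacts.
pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # assumed Hub checkpoint
    export=True,
    rbln_config={"transformer": {"tensor_parallel_size": 4}},  # assumed submodule key
)
pipe.save_pretrained("./cosmos_t2w_rbln")

# export=False: load the already compiled artifacts without recompiling.
pipe = RBLNCosmosTextToWorldPipeline.from_pretrained("./cosmos_t2w_rbln", export=False)
```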
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
CHANGED

@@ -87,8 +87,38 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipeline):
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
         rbln_config: Dict[str, Any] = {},
-        **kwargs:
+        **kwargs: Any,
     ):
+        """
+        Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
+
+        This method has two distinct operating modes:
+
+        - When `export=True`: Takes a PyTorch-based diffusion model, compiles it for RBLN NPUs, and loads the compiled model
+        - When `export=False`: Loads an already compiled RBLN model from `model_id` without recompilation
+
+        It supports various diffusion pipelines including Stable Diffusion, Kandinsky, ControlNet, and other diffusers-based models.
+
+        Args:
+            model_id (`str`):
+                The model ID or path to the pretrained model to load. Can be either:
+
+                - A model ID from the HuggingFace Hub
+                - A local path to a saved model directory
+            export:
+                If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
+                If False, loads an already compiled RBLN model from `model_id` without recompilation.
+            safety_checker:
+                Optional custom safety checker to use instead of the default one. Only used when `export=True`.
+            rbln_config:
+                Configuration options for RBLN compilation. Can include settings for specific submodules
+                such as `text_encoder`, `unet`, and `vae`. Configuration can be tailored to the specific
+                pipeline being compiled.
+            kwargs:
+                Additional arguments to pass to the underlying diffusion pipeline constructor or the
+                RBLN compilation process. These may include parameters specific to individual submodules
+                or the particular diffusion pipeline being used.
+        """
+
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
CHANGED

@@ -22,12 +22,7 @@ from diffusers import (
     UNet2DConditionModel,
     VQModel,
 )
-from transformers import (
-    CLIPImageProcessor,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
 from ...modeling_diffusers import RBLNDiffusionMixin
optimum/rbln/modeling.py
CHANGED

@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, ge
 import rebel
 import torch
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
-from transformers import
+from transformers import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput

 from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig

@@ -34,6 +34,49 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+def _get_dtype(
+    cls,
+    dtype: Optional[Union[str, torch.dtype, dict]],
+    config: PretrainedConfig,
+) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
+    dtype_orig = None
+
+    if dtype is not None:
+        if isinstance(dtype, str):
+            if dtype == "auto":
+                if hasattr(config, "dtype") and config.dtype is not None:
+                    dtype = config.dtype
+                else:
+                    dtype = torch.get_default_dtype()
+            elif hasattr(torch, dtype):
+                dtype = getattr(torch, dtype)
+            config.dtype = dtype
+        elif isinstance(dtype, torch.dtype):
+            config.dtype = dtype
+        elif isinstance(dtype, dict):
+            for key, curr_dtype in dtype.items():
+                if hasattr(config, key):
+                    value = getattr(config, key)
+                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
+                    value.dtype = curr_dtype
+            # main torch dtype for modules that aren't part of any sub-config
+            dtype = dtype.get("")
+            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
+            config.dtype = dtype
+            if dtype is None:
+                dtype = torch.float32
+        else:
+            raise ValueError(f"Invalid dtype: {dtype}")
+
+        dtype_orig = cls._set_default_dtype(dtype)
+    else:
+        # Use default dtype
+        default_dtype = torch.get_default_dtype()
+        config.dtype = default_dtype
+
+    return config, dtype, dtype_orig
+
+
 class RBLNModel(RBLNBaseModel):
     @classmethod
     def update_kwargs(cls, kwargs):
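The new `_get_dtype` helper normalizes a user-supplied dtype (the string "auto", a dtype name such as "bfloat16", a `torch.dtype`, or a per-sub-config dict) and records the result on the config. A standalone sketch of the string and `torch.dtype` resolution rules, with the sub-config dict handling and `_set_default_dtype` bookkeeping omitted:

```python
import torch

def resolve_dtype(dtype, config_dtype=None):
    """Simplified mirror of the resolution rules in _get_dtype above."""
    if dtype == "auto":
        # "auto" falls back to config.dtype when present, else torch's default dtype
        return config_dtype if config_dtype is not None else torch.get_default_dtype()
    if isinstance(dtype, str) and hasattr(torch, dtype):
        return getattr(torch, dtype)  # e.g. "bfloat16" -> torch.bfloat16
    if isinstance(dtype, torch.dtype):
        return dtype
    raise ValueError(f"Invalid dtype: {dtype}")

assert resolve_dtype("bfloat16") is torch.bfloat16
assert resolve_dtype("auto") is torch.get_default_dtype()
assert resolve_dtype(torch.float16) is torch.float16
```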
@@ -70,6 +113,10 @@ class RBLNModel(RBLNBaseModel):
         )
         return compiled_model

+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        return model
+
     @classmethod
     def from_model(
         cls,

@@ -78,18 +125,20 @@ class RBLNModel(RBLNBaseModel):
         rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
-        **kwargs:
+        **kwargs: Any,
     ) -> "RBLNModel":
         """
         Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
         This method performs the actual model conversion and compilation process.

         Args:
-            model: The PyTorch model to be compiled.
-
+            model (PreTrainedModel): The PyTorch model to be compiled.
+                The object must be an instance of the HuggingFace transformers PreTrainedModel class.
+            config (Optional[PretrainedConfig]): The configuration object associated with the model.
+            rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
+                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
                 For detailed configuration options, see the specific model's configuration class documentation.
-
-            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.

         The method performs the following steps:

@@ -99,8 +148,10 @@ class RBLNModel(RBLNBaseModel):
         4. Saves the compiled model and configurations

         Returns:
-            A RBLN model instance ready for inference on RBLN NPU devices.
+            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
         """
+
+        model = cls._reconstruct_model_if_needed(model)
         preprocessors = kwargs.pop("preprocessors", [])
         rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
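Based on the updated signature and docstring, a `from_model` usage sketch; the checkpoint ID and the `tensor_parallel_size` field are illustrative assumptions:

```python
# Sketch: compile an in-memory HuggingFace model, per the from_model docstring above.
from transformers import AutoModelForCausalLM
from optimum.rbln import RBLNLlamaForCausalLM  # class family referenced in the docstring's config example

hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
rbln_model = RBLNLlamaForCausalLM.from_model(
    hf_model,
    rbln_config={"tensor_parallel_size": 4},  # dict form; an RBLNLlamaForCausalLMConfig instance also works
)
rbln_model.save_pretrained("./llama2_rbln")
```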
@@ -119,9 +170,6 @@ class RBLNModel(RBLNBaseModel):
         # Save configs
         if config is None:
             config = model.config
-            # remote_config
-            if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
-                config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)

         if hasattr(model, "can_generate") and model.can_generate():
             import json

@@ -147,6 +195,7 @@ class RBLNModel(RBLNBaseModel):
                 model=model,
                 model_save_dir=save_dir,
                 rbln_config=rbln_config,
+                preprocessors=preprocessors,
                 **kwargs,
             )
         else:

@@ -209,6 +258,7 @@ class RBLNModel(RBLNBaseModel):
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
+
         return cls.get_hf_class().from_pretrained(
             model_id,
             subfolder=subfolder,

@@ -241,31 +291,33 @@ class RBLNModel(RBLNBaseModel):
             for compiled_model in compiled_models
         ]

-    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs:
+    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
         """
-        Defines the forward pass of
+        Defines the forward pass of `RBLNModel`. The interface mirrors HuggingFace conventions so it can act as a drop-in
+        replacement in many cases.

-        This method executes the compiled RBLN model on RBLN NPU devices while
-
-
+        This method executes the compiled RBLN model on RBLN NPU devices while remaining fully compatible with Hugging Face
+        Transformers and Diffusers APIs. In practice, `RBLNModel` can replace models built on `torch.nn.Module` — including
+        `transformers.PreTrainedModel` implementations and Diffusers components based on `diffusers.ModelMixin` — enabling
+        seamless integration into existing workflows.

         Args:
-
+            args: Variable length argument list containing model inputs. The format matches the original
                 HuggingFace model's forward method signature (e.g., input_ids, attention_mask for
                 transformers models, or sample, timestep for diffusers models).
             return_dict:
                 Whether to return outputs as a dictionary-like object or as a tuple. When `None`:
                 - For transformers models: Uses `self.config.use_return_dict` (typically `True`)
                 - For diffusers models: Defaults to `True`
-
+            kwargs: Arbitrary keyword arguments containing additional model inputs and parameters,
                 matching the original HuggingFace model's interface.

         Returns:
             Model outputs in the same format as the original HuggingFace model.

-
+            If `return_dict=True`, Returns a dictionary-like object (e.g., BaseModelOutput,
             CausalLMOutput) with named fields such as `logits`, `hidden_states`, etc.
-
+            If `return_dict=False`, Returns a tuple containing the raw model outputs.

         Note:
             - This method maintains the exact same interface as the original HuggingFace model's forward method
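A short sketch of the `return_dict` contract described above, reusing the `rbln_model` handle from the `from_model` sketch earlier; the input shape is illustrative and real models may require additional inputs:

```python
import torch

inputs = torch.ones(1, 8, dtype=torch.long)  # dummy input_ids

out = rbln_model(input_ids=inputs, return_dict=True)
print(out.logits.shape)  # dict-like output (e.g., CausalLMOutput) with named fields

out_tuple = rbln_model(input_ids=inputs, return_dict=False)
print(type(out_tuple))  # plain tuple of the raw outputs
```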
optimum/rbln/modeling_base.py
CHANGED

@@ -34,7 +34,7 @@ from .utils.submodule import SubModulesMixin


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel

 logger = get_logger(__name__)

@@ -53,6 +53,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     config_class = AutoConfig
     config_name = "config.json"
     hf_library_name = "transformers"
+    _supports_non_fp32 = False

     def __init__(
         self,

@@ -91,7 +92,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

         self.device = torch.device("cpu")
         self.training = False
-        self.dtype =
+        self.dtype = rbln_config.torch_dtype

         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it

@@ -314,7 +315,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             rbln_config,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=
+            rbln_compiled_models=rbln_compiled_models,
             rbln_submodules=rbln_submodules,
             **kwargs,
         )

@@ -342,32 +343,72 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
         return rbln_config, kwargs

+    @classmethod
+    def _is_compiled(
+        cls,
+        model_id: Union[str, Path],
+        token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+    ) -> bool:
+        # Check if the model is already compiled.
+        try:
+            cls._load_compiled_model_dir(
+                model_id=model_id,
+                token=token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                subfolder=subfolder,
+                local_files_only=local_files_only,
+            )
+            return True
+        except (FileNotFoundError, KeyError):
+            return False
+
     @classmethod
     def from_pretrained(
         cls: Type["RBLNBaseModel"],
         model_id: Union[str, Path],
-        export: bool =
+        export: Optional[bool] = None,
         rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
-        **kwargs:
+        **kwargs: Any,
     ) -> "RBLNBaseModel":
         """
         The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
         User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.

         Args:
-            model_id: The model id of the pre-trained model to be loaded.
-
-
+            model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
+                It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
+                If None, it will be determined based on the existence of the compiled model files in the model_id.
+            rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
+                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
                 For detailed configuration options, see the specific model's configuration class documentation.
-
-            kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.

         Returns:
-            A RBLN model instance ready for inference on RBLN NPU devices.
+            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
         """

         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
+
+        if export is None:
+            export = not cls._is_compiled(
+                model_id=model_id,
+                token=kwargs.get("token"),
+                revision=kwargs.get("revision"),
+                force_download=kwargs.get("force_download", False),
+                cache_dir=kwargs.get("cache_dir"),
+                subfolder=kwargs.get("subfolder", ""),
+                local_files_only=kwargs.get("local_files_only", False),
+            )
+
         from_pretrained_method = cls._export if export else cls._from_pretrained
         return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)
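With `export` now `Optional[bool]`, callers may omit it and let `_is_compiled` decide. A sketch with illustrative paths and IDs:

```python
from optimum.rbln import RBLNLlamaForCausalLM

# Compiled artifacts found at the path -> export resolves to False and the model is simply loaded.
model = RBLNLlamaForCausalLM.from_pretrained("./llama2_rbln")

# A plain HuggingFace checkpoint has no compiled files -> export resolves to True and compilation runs.
model = RBLNLlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```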
@@ -392,7 +433,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         compiled_model = rebel.compile_from_torch(
             model,
             input_info=rbln_compile_config.input_info,
-            fusion=rbln_compile_config.fusion,
             npu=rbln_compile_config.npu,
             tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
             **kwargs,

@@ -400,8 +440,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return compiled_model

     @classmethod
-    def update_rbln_config(
-
+    def update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model: "PreTrainedModel",
+        model_config: "PretrainedConfig",
+        rbln_config: RBLNModelConfig,
+    ) -> RBLNModelConfig:
+        rbln_config.torch_dtype = model.dtype
+        if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+            raise NotImplementedError(
+                f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
+            )
+        rbln_config = cls._update_rbln_config(
+            preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
+        )
         rbln_config.freeze()
         if rbln_config.rbln_model_cls_name != cls.__name__:
             raise NameError(
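`update_rbln_config` now records the source model's dtype and rejects non-fp32 models for classes that keep the default `_supports_non_fp32 = False`. A sketch of the failure mode, assuming the target class has not opted in to non-fp32 support:

```python
import torch
from transformers import AutoModelForCausalLM
from optimum.rbln import RBLNLlamaForCausalLM

bf16_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16)
try:
    RBLNLlamaForCausalLM.from_model(bf16_model)
except NotImplementedError as err:
    print(err)  # "... does not support non-fp32 dtype. Please use float32 dtype."
```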
@@ -444,12 +497,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

         # This method mimics the interface of torch.nn.Module.parameters()
         # specifically for code that uses `next(model.parameters())` to infer
-        # the device or dtype. It yields a single dummy tensor on CPU with
+        # the device or dtype. It yields a single dummy tensor on CPU with model dtype.

         # Warning:
         # This does NOT yield the actual model parameters used by the RBLN runtime.
         # Code relying on iterating through all model parameters will not work as expected.
-        yield torch.tensor([1.0], dtype=
+        yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))

     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
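The shim exists so device/dtype probes keep working. Reusing `rbln_model` from the earlier sketches:

```python
# Common probe pattern that the parameters() shim is designed to satisfy:
param = next(rbln_model.parameters())
print(param.device, param.dtype)  # cpu, and the dtype recorded in rbln_config

# Caveat from the comment above: this is a single dummy tensor,
# so iterating over all parameters will NOT reflect the real runtime weights.
```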
|
@@ -484,9 +537,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
484
537
|
[`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.
|
|
485
538
|
|
|
486
539
|
Args:
|
|
487
|
-
save_directory (
|
|
540
|
+
save_directory (Union[str, Path]):
|
|
488
541
|
Directory where to save the model file.
|
|
489
|
-
push_to_hub (
|
|
542
|
+
push_to_hub (bool):
|
|
490
543
|
Whether or not to push your model to the HuggingFace model hub after saving it.
|
|
491
544
|
|
|
492
545
|
"""
|
|
@@ -523,10 +576,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             # First copy everything to a temporary directory
             shutil.copytree(real_save_dir, tmp_dir)

-            # If everything succeeded,
+            # If everything succeeded, move files to target directory
             if os.path.exists(save_directory_path):
-
-
+                # Merge files from tmp_dir into existing directory
+                def _merge_dir(src_root: str, dst_root: str):
+                    for name in os.listdir(src_root):
+                        src_item = os.path.join(src_root, name)
+                        dst_item = os.path.join(dst_root, name)
+
+                        if os.path.islink(src_item) or os.path.isfile(src_item):
+                            os.makedirs(os.path.dirname(dst_item), exist_ok=True)
+                            if os.path.isdir(dst_item) and not os.path.islink(dst_item):
+                                shutil.rmtree(dst_item)
+                            os.replace(src_item, dst_item)
+                        elif os.path.isdir(src_item):
+                            if os.path.islink(dst_item) or os.path.isfile(dst_item):
+                                os.remove(dst_item)
+                            os.makedirs(dst_item, exist_ok=True)
+                            _merge_dir(src_item, dst_item)
+                        else:
+                            # Fallback for special file types
+                            os.replace(src_item, dst_item)
+
+                _merge_dir(tmp_dir, str(save_directory_path))
+
+                # Remove the temporary directory tree after merge
+                shutil.rmtree(tmp_dir)
+            else:
+                # If target doesn't exist, just rename tmp_dir to target
+                os.rename(tmp_dir, save_directory_path)

         except Exception as e:
             # Clean up the temporary directory if anything fails