optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,455 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import os
|
16
|
+
import pathlib
|
17
|
+
from functools import partial
|
18
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
19
|
+
from unittest.mock import patch
|
20
|
+
|
21
|
+
import rebel
|
22
|
+
import torch
|
23
|
+
from diffusers.utils import is_cosmos_guardrail_available
|
24
|
+
from huggingface_hub import snapshot_download
|
25
|
+
from transformers import AutoTokenizer, SiglipProcessor
|
26
|
+
|
27
|
+
from .... import RBLNAutoModelForCausalLM, RBLNSiglipVisionModel
|
28
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime, UnavailableRuntime
|
29
|
+
from .configuration_cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
|
30
|
+
|
31
|
+
|
32
|
+
if is_cosmos_guardrail_available():
    # Real implementations from the optional ``cosmos_guardrail`` package (and its RetinaFace config).
    from cosmos_guardrail import CosmosSafetyChecker
    from cosmos_guardrail.cosmos_guardrail import (
        COSMOS_GUARDRAIL_CHECKPOINT,
        Aegis,
        Blocklist,
        GuardrailRunner,
        ModelConfig,
        RetinaFaceFilter,
        SafetyClassifier,
        SigLIPEncoder,
        VideoContentSafetyFilter,
        VideoSafetyModel,
    )
    from retinaface.data import cfg_re50

    COSMOS_AVAILABLE = True
else:
    # ``cosmos_guardrail`` is not installed: define placeholder names so this module still imports and the
    # RBLN subclasses below can be declared. RBLNCosmosSafetyChecker checks COSMOS_AVAILABLE and raises
    # ImportError when constructed in this state.
    COSMOS_AVAILABLE = False

    class FailToImportCosmosGuardrail(torch.nn.Module):
        """Common base of the placeholders used when ``cosmos_guardrail`` cannot be imported."""

    class CosmosSafetyChecker(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.CosmosSafetyChecker``."""

    COSMOS_GUARDRAIL_CHECKPOINT = None

    class Aegis(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.Aegis``."""

    class Blocklist(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.Blocklist``."""

    class GuardrailRunner(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.GuardrailRunner``."""

    class ModelConfig(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.ModelConfig``."""

    class RetinaFaceFilter(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.RetinaFaceFilter``."""

    class SafetyClassifier(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.SafetyClassifier``."""

    class SigLIPEncoder(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.SigLIPEncoder``."""

    class VideoContentSafetyFilter(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.VideoContentSafetyFilter``."""

    class VideoSafetyModel(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.VideoSafetyModel``."""

    cfg_re50 = None
|
77
|
+
|
78
|
+
|
79
|
+
def is_compiled_dir(dir: str) -> bool:
    """Return True if ``dir`` is an existing local path containing at least one ``*.rbln`` file.

    The directory tree is searched recursively with ``os.walk``. A path that does not exist
    locally (for example a Hugging Face Hub repo id) yields ``False``.

    Args:
        dir: Local directory to inspect. ``None`` is accepted and yields ``False``; callers pass
            ``checkpoint_id``, whose default ``COSMOS_GUARDRAIL_CHECKPOINT`` is ``None`` when
            ``cosmos_guardrail`` is unavailable.

    Returns:
        Whether any compiled RBLN artifact (``*.rbln``) exists under ``dir``.
    """
    # os.path.exists(None) raises TypeError rather than returning False, so guard None explicitly.
    if dir is None or not os.path.exists(dir):
        return False

    return any(name.endswith(".rbln") for _, _, names in os.walk(dir) for name in names)
|
89
|
+
|
90
|
+
|
91
|
+
def get_image_features(
    self,
    pixel_values: torch.Tensor,
    return_dict: bool = True,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    interpolate_pos_encoding: bool = False,
):
    """Run the vision model on ``pixel_values`` and return element ``[1]`` of its output.

    This function is attached to an ``RBLNSiglipVisionModel`` instance (passed as ``self``) in place of
    ``get_image_features``, so callers receive the pooler output rather than the whole model output.

    A flag passed explicitly as ``None`` falls back to the matching attribute of ``self.config``
    (``output_attentions``, ``output_hidden_states``, ``use_return_dict``); ``self.config`` is only
    read in that case.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    model_outputs = self(
        pixel_values,
        return_dict=return_dict,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )
    return model_outputs[1]
|
112
|
+
|
113
|
+
|
114
|
+
class RBLNSigLIPEncoder(SigLIPEncoder):
    """SigLIP image encoder of the video content safety filter, backed by ``RBLNSiglipVisionModel``.

    If ``checkpoint_id`` already contains compiled artifacts (see ``is_compiled_dir``), the processor and the
    compiled vision model are loaded from ``<checkpoint_id>/video_content_safety_filter/siglip_encoder``.
    Otherwise the upstream ``SigLIPEncoder`` constructor builds the PyTorch model, which is then compiled
    with ``RBLNSiglipVisionModel.from_model`` using the ``rbln_config.siglip_encoder`` settings.
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ):
        # Initialize only nn.Module here; SigLIPEncoder.__init__ runs solely on the non-compiled path below.
        torch.nn.Module.__init__(self)

        if not is_compiled_dir(checkpoint_id):
            super().__init__(model_name, checkpoint_id)
            torch_model = self.model
            del self.model
            encoder_cfg = rbln_config.siglip_encoder
            self.model = RBLNSiglipVisionModel.from_model(
                torch_model,
                rbln_device=encoder_cfg.device,
                rbln_image_size=encoder_cfg.image_size,
                rbln_npu=encoder_cfg.npu,
                rbln_create_runtimes=encoder_cfg.create_runtimes,
                rbln_activate_profiler=encoder_cfg.activate_profiler,
                rbln_optimize_host_memory=encoder_cfg.optimize_host_memory,
            )
        else:
            compiled_dir = pathlib.Path(checkpoint_id).joinpath("video_content_safety_filter", "siglip_encoder")
            self.checkpoint_dir = compiled_dir.as_posix()
            self.processor = SiglipProcessor.from_pretrained(self.checkpoint_dir)

            # RBLNSiglipModel is not used; get_image_features is patched below to return the pooler output.
            encoder_cfg = rbln_config.siglip_encoder
            self.model = RBLNSiglipVisionModel.from_pretrained(
                self.checkpoint_dir,
                rbln_device=encoder_cfg.device,
                rbln_create_runtimes=encoder_cfg.create_runtimes,
                rbln_activate_profiler=encoder_cfg.activate_profiler,
                rbln_optimize_host_memory=encoder_cfg.optimize_host_memory,
            )

        self.rbln_config = rbln_config

        def _pooled_image_features(*args, **kwargs):
            # self.model is looked up at call time, so the patched method follows any later reassignment.
            return get_image_features(self.model, *args, **kwargs)

        self.model.get_image_features = _pooled_image_features

    def save_pretrained(self, checkpoint_id: str):
        """Save the compiled vision model and the processor to ``<checkpoint_id>/video_content_safety_filter/siglip_encoder``."""
        target_dir = pathlib.Path(checkpoint_id).joinpath("video_content_safety_filter", "siglip_encoder").as_posix()
        self.model.save_pretrained(target_dir)
        self.processor.save_pretrained(target_dir)
|
158
|
+
|
159
|
+
|
160
|
+
class RBLNRetinaFaceFilter(RetinaFaceFilter):
    """RetinaFace postprocessor whose detector network runs as an RBLN runtime.

    If ``checkpoint_id`` already contains compiled artifacts (see ``is_compiled_dir``), the compiled model is
    loaded from ``<checkpoint_id>/face_blur_filter/retinaface.rbln``. Otherwise the upstream
    ``RetinaFaceFilter`` builds the PyTorch network, which is compiled with ``rebel.compile_from_torch`` for a
    fixed input of shape ``[batch_size, 3, H, W]`` where ``(H, W)`` is
    ``rbln_config.face_blur_filter.image_size``.

    Args:
        checkpoint_id: Compiled directory, or the checkpoint passed to the upstream constructor.
        batch_size: Batch size used by the filter; on the non-compiled path it also fixes the compiled input shape.
        confidence_threshold: Detection confidence threshold stored on the instance.
        rbln_config: Safety-checker RBLN configuration; its ``face_blur_filter`` section is read.

    Raises:
        rebel.core.exception.RBLNRuntimeError: If the RBLN runtime cannot be created.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        batch_size: int = 1,
        confidence_threshold: float = 0.7,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ):
        torch.nn.Module.__init__(self)
        if is_compiled_dir(checkpoint_id):
            self.compiled_model = rebel.RBLNCompiledModel(
                pathlib.Path(checkpoint_id) / "face_blur_filter" / "retinaface.rbln"
            )
            # Work on a copy so setting "pretrain" does not mutate the module-level cfg_re50 dict in place.
            self.cfg = dict(cfg_re50)
            self.batch_size = batch_size
            self.confidence_threshold = confidence_threshold
            self.cfg["pretrain"] = False
        else:
            # Force the upstream checkpoint load onto CPU with weights_only=True.
            with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
                super().__init__(checkpoint_id)
            # The upstream constructor is called with checkpoint_id only, so apply the caller's settings
            # explicitly; batch_size determines the compiled input shape below.
            self.batch_size = batch_size
            self.confidence_threshold = confidence_threshold
            net = self.net
            del self.net
            self.compiled_model = rebel.compile_from_torch(
                net,
                input_info=[
                    (
                        "frames",
                        [
                            self.batch_size,
                            3,
                            rbln_config.face_blur_filter.image_size[0],
                            rbln_config.face_blur_filter.image_size[1],
                        ],
                        "float32",
                    )
                ],
                npu=rbln_config.face_blur_filter.npu,
            )

        self.rbln_config = rbln_config
        self.net = RBLNPytorchRuntime(self._create_runtime())

    def _create_runtime(self):
        """Create the RBLN runtime for ``self.compiled_model``, or an ``UnavailableRuntime`` if runtimes are disabled."""
        face_cfg = self.rbln_config.face_blur_filter
        if not face_cfg.create_runtimes:
            return UnavailableRuntime()
        try:
            return rebel.Runtime(
                self.compiled_model,
                tensor_type="pt",
                device=face_cfg.device,
                activate_profiler=face_cfg.activate_profiler,
            )
        except rebel.core.exception.RBLNRuntimeError as e:
            error_msg = (
                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
                f"If you only need to compile the model without loading it to NPU, you can use:\n"
                f" from_pretrained(..., rbln_create_runtimes=False) or\n"
                f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
                f"Make sure your NPU is properly installed and operational."
            )
            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

    def save_pretrained(self, checkpoint_id: str):
        """Write the compiled model to ``<checkpoint_id>/face_blur_filter/retinaface.rbln``, creating the directory."""
        cache_path = pathlib.Path(checkpoint_id) / "face_blur_filter"
        cache_path.mkdir(parents=True, exist_ok=True)
        self.compiled_model.save(cache_path / "retinaface.rbln")
|
229
|
+
|
230
|
+
|
231
|
+
class RBLNVideoSafetyModel(VideoSafetyModel):
    """Video safety classifier head running as an RBLN runtime.

    If ``checkpoint_id`` already contains compiled artifacts (see ``is_compiled_dir``), the compiled model is
    loaded from ``<checkpoint_id>/video_content_safety_filter/safety_filter.rbln``. Otherwise the
    ``safety_filter.pt`` weights are fetched with ``snapshot_download``, loaded into a ``SafetyClassifier`` and
    compiled for an input of shape ``[batch_size, input_size]`` taken from ``rbln_config.video_safety_model``.

    Raises:
        rebel.core.exception.RBLNRuntimeError: If the RBLN runtime cannot be created.
    """

    def __init__(
        self,
        config: ModelConfig,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
    ):
        torch.nn.Module.__init__(self)
        self.config = config
        self.num_classes = config.num_classes
        self.rbln_config = rbln_config

        if is_compiled_dir(checkpoint_id):
            self.compiled_model = rebel.RBLNCompiledModel(
                pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "safety_filter.rbln"
            )
        else:
            self.compiled_model = self._compile_from_checkpoint(checkpoint_id)

        self.network = RBLNPytorchRuntime(self._create_runtime())

    def _compile_from_checkpoint(self, checkpoint_id: str):
        """Load ``safety_filter.pt`` from ``checkpoint_id`` into a SafetyClassifier and compile it for RBLN."""
        safety_cfg = self.rbln_config.video_safety_model
        network = SafetyClassifier(input_size=safety_cfg.input_size, num_classes=self.num_classes)
        network.eval()

        checkpoint_dir = snapshot_download(checkpoint_id)
        checkpoint_dir = (pathlib.Path(checkpoint_dir) / "video_content_safety_filter").as_posix()
        safety_filter_local_path = os.path.join(checkpoint_dir, "safety_filter.pt")
        # Map tensors to CPU (as the RetinaFace load does) so weights saved from an accelerator
        # still load on hosts without one.
        checkpoint = torch.load(safety_filter_local_path, weights_only=True, map_location=torch.device("cpu"))
        network.load_state_dict({k.replace("network.", ""): v for k, v in checkpoint["model"].items()})

        return rebel.compile_from_torch(
            network,
            input_info=[
                (
                    "data",
                    [
                        safety_cfg.batch_size,
                        safety_cfg.input_size,
                    ],
                    "float32",
                )
            ],
            npu=safety_cfg.npu,
        )

    def _create_runtime(self):
        """Create the RBLN runtime for ``self.compiled_model``, or an ``UnavailableRuntime`` if runtimes are disabled."""
        safety_cfg = self.rbln_config.video_safety_model
        if not safety_cfg.create_runtimes:
            return UnavailableRuntime()
        try:
            return rebel.Runtime(
                self.compiled_model,
                tensor_type="pt",
                device=safety_cfg.device,
                activate_profiler=safety_cfg.activate_profiler,
            )
        except rebel.core.exception.RBLNRuntimeError as e:
            error_msg = (
                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
                f"If you only need to compile the model without loading it to NPU, you can use:\n"
                f" from_pretrained(..., rbln_create_runtimes=False) or\n"
                f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
                f"Make sure your NPU is properly installed and operational."
            )
            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

    def save_pretrained(self, checkpoint_id: str):
        """Write the compiled model to ``<checkpoint_id>/video_content_safety_filter/safety_filter.rbln``."""
        cache_path = pathlib.Path(checkpoint_id) / "video_content_safety_filter"
        cache_path.mkdir(parents=True, exist_ok=True)
        self.compiled_model.save(cache_path / "safety_filter.rbln")

    def parameters(self):
        """Yield a single CPU float32 placeholder tensor instead of real parameters.

        NOTE(review): the weights live in the compiled RBLN model; presumably this lets upstream code infer
        device/dtype via ``next(self.parameters())`` — confirm against VideoSafetyModel.
        """
        yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
|
307
|
+
|
308
|
+
|
309
|
+
class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
    """Video content safety filter made of an RBLN SigLIP encoder followed by an RBLN safety classifier.

    The classifier is configured with ``ModelConfig(input_size=1152, num_classes=7)``.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
    ):
        # Skip VideoContentSafetyFilter.__init__; both sub-models are built here as RBLN variants.
        torch.nn.Module.__init__(self)
        self.rbln_config = rbln_config
        self.encoder = RBLNSigLIPEncoder(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        self.model = RBLNVideoSafetyModel(
            ModelConfig(input_size=1152, num_classes=7),
            checkpoint_id=checkpoint_id,
            rbln_config=rbln_config,
        )

    def save_pretrained(self, checkpoint_id: str):
        """Save the safety classifier, then the SigLIP encoder, under ``checkpoint_id``."""
        for component in (self.model, self.encoder):
            component.save_pretrained(checkpoint_id)
|
325
|
+
|
326
|
+
|
327
|
+
class RBLNAegis(Aegis):
    """Aegis text safety model (LlamaGuard base + Aegis adapter) served through ``RBLNAutoModelForCausalLM``.

    If ``checkpoint_id`` already contains compiled artifacts (see ``is_compiled_dir``), the tokenizer and the
    compiled model are loaded from ``<checkpoint_id>/aegis``. Otherwise the upstream ``Aegis`` constructor
    builds the PEFT model, the adapter is merged into the base weights, and the result is compiled with
    tensor parallel size 4.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        base_model_id: str = "meta-llama/LlamaGuard-7b",
        aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ) -> None:
        if not is_compiled_dir(checkpoint_id):
            super().__init__(checkpoint_id, base_model_id, aegis_adapter)
            merged_model = self.model.merge_and_unload()  # fold the PEFT adapter into the base weights
            del self.model

            aegis_cfg = rbln_config.aegis
            self.model = RBLNAutoModelForCausalLM.from_model(
                merged_model,
                rbln_tensor_parallel_size=4,
                rbln_device=aegis_cfg.device,
                rbln_create_runtimes=aegis_cfg.create_runtimes,
                rbln_npu=aegis_cfg.npu,
                rbln_activate_profiler=aegis_cfg.activate_profiler,
                rbln_optimize_host_memory=aegis_cfg.optimize_host_memory,
            )
        else:
            # Compiled path: skip Aegis.__init__ and load everything from the compiled directory.
            torch.nn.Module.__init__(self)
            compiled_dir = pathlib.Path(checkpoint_id) / "aegis"
            self.tokenizer = AutoTokenizer.from_pretrained(compiled_dir)
            aegis_cfg = rbln_config.aegis
            self.model = RBLNAutoModelForCausalLM.from_pretrained(
                compiled_dir,
                rbln_device=aegis_cfg.device,
                rbln_create_runtimes=aegis_cfg.create_runtimes,
                rbln_activate_profiler=aegis_cfg.activate_profiler,
                rbln_optimize_host_memory=aegis_cfg.optimize_host_memory,
            )

        self.rbln_config = rbln_config
        self.dtype = torch.bfloat16
        self.device = torch.device("cpu")

    def save_pretrained(self, checkpoint_id: str):
        """Save the compiled model and the tokenizer to ``<checkpoint_id>/aegis``."""
        target_dir = pathlib.Path(checkpoint_id).joinpath("aegis")
        self.model.save_pretrained(target_dir)
        self.tokenizer.save_pretrained(target_dir)
|
370
|
+
|
371
|
+
|
372
|
+
class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
    """
    RBLN-accelerated implementation of Cosmos Safety Checker.

    Text guardrail: a ``Blocklist`` plus an RBLN-compiled Aegis model.
    Video guardrail: an RBLN video content safety filter plus an RBLN RetinaFace postprocessor.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        aegis_model_id: str = "meta-llama/LlamaGuard-7b",
        aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ) -> None:
        torch.nn.Module.__init__(self)
        if not COSMOS_AVAILABLE:
            raise ImportError(
                "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
            )

        # Accept a plain dict of options, or fall back to the default configuration.
        if isinstance(rbln_config, dict):
            rbln_config = RBLNCosmosSafetyCheckerConfig(**rbln_config)
        elif rbln_config is None:
            rbln_config = RBLNCosmosSafetyCheckerConfig()

        blocklist = Blocklist(COSMOS_GUARDRAIL_CHECKPOINT)  # Changed since it cannot be saved
        aegis = RBLNAegis(
            checkpoint_id=checkpoint_id,
            base_model_id=aegis_model_id,
            aegis_adapter=aegis_adapter_id,
            rbln_config=rbln_config,
        )
        self.text_guardrail = GuardrailRunner(safety_models=[blocklist, aegis])

        content_filter = RBLNVideoContentSafetyFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        face_filter = RBLNRetinaFaceFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        self.video_guardrail = GuardrailRunner(safety_models=[content_filter], postprocessors=[face_filter])

        self.rbln_config = rbln_config

    def save_pretrained(self, save_dir: str):
        """Save every RBLN sub-model and the RBLN config to ``save_dir``."""

        def _save_matching(members, rbln_type):
            # Only RBLN components know how to save themselves; other guardrail members are skipped.
            for member in members:
                if isinstance(member, rbln_type):
                    member.save_pretrained(save_dir)

        _save_matching(self.text_guardrail.safety_models, RBLNAegis)
        _save_matching(self.video_guardrail.safety_models, RBLNVideoContentSafetyFilter)
        _save_matching(self.video_guardrail.postprocessors, RBLNRetinaFaceFilter)

        self.rbln_config._frozen = True  # Ad-hoc to save config
        self.rbln_config.save(save_dir)

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_id: str,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
        subfolder: Optional[str] = None,
        export: Optional[bool] = True,
        **kwargs,
    ):
        """Build a checker from ``checkpoint_id`` (optionally joined with ``subfolder``).

        RBLN options in ``kwargs`` are merged into ``rbln_config`` by ``prepare_rbln_config``; any argument
        left over afterwards raises ``ValueError``. ``export`` is accepted but not used here.
        """
        rbln_config, leftover = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
        if leftover:
            raise ValueError(f"Unexpected arguments: {leftover.keys()}")

        resolved_id = checkpoint_id if subfolder is None else os.path.join(checkpoint_id, subfolder)
        return cls(checkpoint_id=resolved_id, rbln_config=rbln_config)

    @classmethod
    def prepare_rbln_config(
        cls, rbln_config: Optional[Union[Dict[str, Any], RBLNCosmosSafetyCheckerConfig]] = None, **kwargs
    ) -> Tuple[RBLNCosmosSafetyCheckerConfig, Dict[str, Any]]:
        """Extract RBLN options from ``kwargs`` and turn them, with ``rbln_config``, into an RBLNCosmosSafetyCheckerConfig.

        Returns:
            The resulting config and the kwargs that were not consumed.
        """
        config, remaining = RBLNCosmosSafetyCheckerConfig.initialize_from_kwargs(rbln_config, **kwargs)
        return config, remaining
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from typing import Any, Dict, Optional
|
17
|
+
|
18
|
+
from diffusers import CosmosTextToWorldPipeline
|
19
|
+
from diffusers.schedulers import EDMEulerScheduler
|
20
|
+
from transformers import T5TokenizerFast
|
21
|
+
|
22
|
+
from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
|
23
|
+
from ....utils.logging import get_logger
|
24
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
25
|
+
from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
|
26
|
+
from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
|
27
|
+
from .cosmos_guardrail import RBLNCosmosSafetyChecker
|
28
|
+
|
29
|
+
|
30
|
+
logger = get_logger(__name__)
|
31
|
+
|
32
|
+
|
33
|
+
class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
    """
    RBLN-accelerated implementation of Cosmos Text to World pipeline for text-to-video generation.

    This pipeline compiles the Cosmos Text to World submodules (T5 text encoder, Cosmos 3D transformer
    and Cosmos VAE) to run efficiently on RBLN NPUs, enabling high-performance video generation from
    text prompts. An optional RBLN Cosmos safety checker is managed alongside them.
    """

    original_class = CosmosTextToWorldPipeline
    _submodules = ["text_encoder", "transformer", "vae"]
    _optional_submodules = ["safety_checker"]

    def __init__(
        self,
        text_encoder: RBLNT5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: RBLNCosmosTransformer3DModel,
        vae: RBLNAutoencoderKLCosmos,
        scheduler: EDMEulerScheduler,
        safety_checker: RBLNCosmosSafetyChecker = None,
    ):
        """Assemble the pipeline from already-built components.

        If ``safety_checker`` is ``None``, a default ``RBLNCosmosSafetyChecker()`` is created
        so the parent pipeline always receives one.
        """
        if safety_checker is None:
            safety_checker = RBLNCosmosSafetyChecker()

        super().__init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            safety_checker=safety_checker,
        )

    def handle_additional_kwargs(self, **kwargs):
        """Drop call-time kwargs that conflict with values fixed when the transformer was compiled.

        ``num_frames`` is checked against ``transformer.rbln_config.num_frames``.
        ``max_sequence_length`` is checked against ``transformer.rbln_config.max_seq_len``.
        When the caller passes a different value, a warning is logged and the kwarg is removed.

        Returns:
            The ``kwargs`` dict with conflicting entries removed.
        """
        # (user-facing kwarg name, attribute name on transformer.rbln_config)
        compiled_pairs = (("num_frames", "num_frames"), ("max_sequence_length", "max_seq_len"))
        for user_key, config_key in compiled_pairs:
            if user_key not in kwargs:
                continue
            compiled_value = getattr(self.transformer.rbln_config, config_key)
            if kwargs[user_key] != compiled_value:
                logger.warning(
                    f"The transformer in this pipeline is compiled with '{config_key}={compiled_value}'. '{user_key}' set by the user will be ignored"
                )
                kwargs.pop(user_key)
        return kwargs

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        *,
        export: bool = False,
        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        **kwargs: Dict[str, Any],
    ):
        """Load (and optionally compile) the pipeline.

        Args:
            model_id: Model identifier or path forwarded to the parent ``from_pretrained``.
            export: When True and no ``safety_checker`` is given, a safety checker is built
                from ``rbln_config.safety_checker``.
            safety_checker: Pre-built safety checker to use instead of creating one.
            rbln_config: RBLN configuration dict. ``None`` (the default) means an empty dict.
            **kwargs: Remaining options. They are first passed through the pipeline config class's
                ``initialize_from_kwargs``, and whatever it returns is forwarded to the parent.
        """
        # Use a fresh dict per call. A `{}` default would be one object shared by every call,
        # and initialize_from_kwargs receives it directly, so any mutation would leak across calls.
        if rbln_config is None:
            rbln_config = {}
        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
        if safety_checker is None and export:
            safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)

        return super().from_pretrained(
            model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
        )
|
@@ -0,0 +1,98 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from typing import Any, Dict, Optional
|
17
|
+
|
18
|
+
from diffusers import CosmosVideoToWorldPipeline
|
19
|
+
from diffusers.schedulers import EDMEulerScheduler
|
20
|
+
from transformers import T5TokenizerFast
|
21
|
+
|
22
|
+
from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
|
23
|
+
from ....utils.logging import get_logger
|
24
|
+
from ...modeling_diffusers import RBLNDiffusionMixin
|
25
|
+
from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
|
26
|
+
from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
|
27
|
+
from .cosmos_guardrail import RBLNCosmosSafetyChecker
|
28
|
+
|
29
|
+
|
30
|
+
logger = get_logger(__name__)
|
31
|
+
|
32
|
+
|
33
|
+
class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipeline):
    """
    RBLN-accelerated implementation of Cosmos Video to World pipeline for video-to-video generation.

    This pipeline compiles the Cosmos Video to World submodules (T5 text encoder, Cosmos 3D transformer
    and Cosmos VAE) to run efficiently on RBLN NPUs, enabling high-performance video generation.
    An optional RBLN Cosmos safety checker is managed alongside them.
    """

    original_class = CosmosVideoToWorldPipeline
    _submodules = ["text_encoder", "transformer", "vae"]
    _optional_submodules = ["safety_checker"]

    def __init__(
        self,
        text_encoder: RBLNT5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: RBLNCosmosTransformer3DModel,
        vae: RBLNAutoencoderKLCosmos,
        scheduler: EDMEulerScheduler,
        safety_checker: RBLNCosmosSafetyChecker = None,
    ):
        """Assemble the pipeline from already-built components.

        If ``safety_checker`` is ``None``, a default ``RBLNCosmosSafetyChecker()`` is created
        so the parent pipeline always receives one.
        """
        if safety_checker is None:
            safety_checker = RBLNCosmosSafetyChecker()

        super().__init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            safety_checker=safety_checker,
        )

    def handle_additional_kwargs(self, **kwargs):
        """Drop call-time kwargs that conflict with values fixed when the transformer was compiled.

        ``num_frames`` is checked against ``transformer.rbln_config.num_frames``.
        ``max_sequence_length`` is checked against ``transformer.rbln_config.max_seq_len``.
        When the caller passes a different value, a warning is logged and the kwarg is removed.

        Returns:
            The ``kwargs`` dict with conflicting entries removed.
        """
        # (user-facing kwarg name, attribute name on transformer.rbln_config)
        compiled_pairs = (("num_frames", "num_frames"), ("max_sequence_length", "max_seq_len"))
        for user_key, config_key in compiled_pairs:
            if user_key not in kwargs:
                continue
            compiled_value = getattr(self.transformer.rbln_config, config_key)
            if kwargs[user_key] != compiled_value:
                logger.warning(
                    f"The transformer in this pipeline is compiled with '{config_key}={compiled_value}'. '{user_key}' set by the user will be ignored"
                )
                kwargs.pop(user_key)
        return kwargs

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        *,
        export: bool = False,
        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        **kwargs: Dict[str, Any],
    ):
        """Load (and optionally compile) the pipeline.

        Args:
            model_id: Model identifier or path forwarded to the parent ``from_pretrained``.
            export: When True and no ``safety_checker`` is given, a safety checker is built
                from ``rbln_config.safety_checker``.
            safety_checker: Pre-built safety checker to use instead of creating one.
            rbln_config: RBLN configuration dict. ``None`` (the default) means an empty dict.
            **kwargs: Remaining options. They are first passed through the pipeline config class's
                ``initialize_from_kwargs``, and whatever it returns is forwarded to the parent.
        """
        # Use a fresh dict per call. A `{}` default would be one object shared by every call,
        # and initialize_from_kwargs receives it directly, so any mutation would leak across calls.
        if rbln_config is None:
            rbln_config = {}
        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
        if safety_checker is None and export:
            safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)

        return super().from_pretrained(
            model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
        )
|
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin
|
|
19
19
|
|
20
20
|
|
21
21
|
class RBLNKandinskyV22Pipeline(RBLNDiffusionMixin, KandinskyV22Pipeline):
|
22
|
+
"""
|
23
|
+
RBLN-accelerated implementation of Kandinsky 2.2 pipeline for text-to-image generation.
|
24
|
+
|
25
|
+
This pipeline compiles Kandinsky 2.2 models to run efficiently on RBLN NPUs, enabling high-performance
|
26
|
+
inference for generating images with distinctive artistic style and enhanced visual quality.
|
27
|
+
"""
|
28
|
+
|
22
29
|
original_class = KandinskyV22Pipeline
|
23
30
|
_rbln_config_class = RBLNKandinskyV22PipelineConfig
|
24
31
|
_submodules = ["unet", "movq"]
|