optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/controlnet.py

```diff
@@ -12,17 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import importlib
 from typing import TYPE_CHECKING, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
-from diffusers.models.controlnet import ControlNetOutput
+from diffusers.models.controlnets.controlnet import ControlNetOutput
 from transformers import PretrainedConfig
 
 from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ...modeling import RBLNModel
 from ...utils.logging import get_logger
+from ...utils.model_utils import get_rbln_model_cls
 from ..configurations import RBLNControlNetModelConfig
 from ..modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
```
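The functional change in this hunk is the import path: recent diffusers releases moved `ControlNetOutput` into the `diffusers.models.controlnets` subpackage. Code that has to work across both layouts can fall back at import time; the shim below is an illustrative sketch, not part of this package.

```python
# Illustrative compatibility shim (not from this package): resolve
# ControlNetOutput across the old and new diffusers module layouts.
try:
    # newer diffusers: controlnets subpackage
    from diffusers.models.controlnets.controlnet import ControlNetOutput
except ImportError:
    # older diffusers: flat controlnet module
    from diffusers.models.controlnet import ControlNetOutput
```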
```diff
@@ -98,6 +98,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 
 
 class RBLNControlNetModel(RBLNModel):
+    """
+    RBLN implementation of ControlNetModel for diffusion models.
+
+    This model is used to accelerate ControlNetModel models from diffusers library on RBLN NPUs.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """
+
     hf_library_name = "diffusers"
     auto_model_class = ControlNetModel
     output_class = ControlNetOutput
```
```diff
@@ -122,13 +131,10 @@ class RBLNControlNetModel(RBLNModel):
 
     @classmethod
     def update_rbln_config_using_pipe(
-        cls,
-        pipe: RBLNDiffusionMixin,
-        rbln_config: "RBLNDiffusionMixinConfig",
-        submodule_name: str,
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
     ) -> "RBLNDiffusionMixinConfig":
-        rbln_vae_cls =
-        rbln_unet_cls =
+        rbln_vae_cls = get_rbln_model_cls(f"RBLN{pipe.vae.__class__.__name__}")
+        rbln_unet_cls = get_rbln_model_cls(f"RBLN{pipe.unet.__class__.__name__}")
 
         rbln_config.controlnet.max_seq_len = pipe.text_encoder.config.max_position_embeddings
         text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
```
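Together with the dropped `import importlib` in the first hunk, these lines replace an importlib-based lookup (its right-hand sides are cut off in this diff view) with the new `get_rbln_model_cls` helper added in `optimum/rbln/utils/model_utils.py` (+20 lines per the file list). A minimal sketch of what such a lookup plausibly does, assuming the helper resolves class names on the top-level `optimum.rbln` package; the shipped implementation may differ in detail:

```python
# Minimal sketch, assuming get_rbln_model_cls resolves class names on the
# top-level optimum.rbln package; the real helper may differ.
import importlib


def get_rbln_model_cls(cls_name: str) -> type:
    cls = getattr(importlib.import_module("optimum.rbln"), cls_name, None)
    if cls is None:
        raise AttributeError(f"optimum.rbln has no model class named {cls_name!r}")
    return cls


# e.g. mapping a pipeline's vae to its RBLN counterpart:
# get_rbln_model_cls(f"RBLN{pipe.vae.__class__.__name__}")  -> RBLNAutoencoderKL
```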
optimum/rbln/diffusers/models/transformers/prior_transformer.py

```diff
@@ -56,6 +56,16 @@ class _PriorTransformer(torch.nn.Module):
 
 
 class RBLNPriorTransformer(RBLNModel):
+    """
+    RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.
+
+    The Prior Transformer takes text and/or image embeddings from encoders (like CLIP) and
+    maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """
+
     hf_library_name = "diffusers"
     auto_model_class = PriorTransformer
     _output_class = PriorTransformerOutput
```
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py (new file, +321 lines)

```python
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union

import rebel
import torch
from diffusers import CosmosTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_cosmos import (
    CosmosEmbedding,
    CosmosLearnablePositionalEmbed,
    CosmosPatchEmbed,
    CosmosRotaryPosEmbed,
)
from torchvision import transforms

from ....configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNModelConfig
from ....modeling import RBLNModel
from ....utils.logging import get_logger
from ...configurations import RBLNCosmosTransformer3DModelConfig


if TYPE_CHECKING:
    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel

    from ...modeling_diffusers import RBLNCosmosTransformer3DModelConfig, RBLNDiffusionMixin, RBLNDiffusionMixinConfig


logger = get_logger(__name__)


class CosmosTransformer3DModelWrapper(torch.nn.Module):
    def __init__(
        self,
        model: CosmosTransformer3DModel,
        num_latent_frames: int = 16,
        latent_height: int = 88,
        latent_width: int = 160,
    ) -> None:
        super().__init__()
        self.model = model
        self.num_latent_frames = num_latent_frames
        self.latent_height = latent_height
        self.latent_width = latent_width
        self.p_t, self.p_h, self.p_w = model.config.patch_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        embedded_timestep: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb_0: torch.Tensor,
        image_rotary_emb_1: torch.Tensor,
        extra_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        image_rotary_emb = [image_rotary_emb_0, image_rotary_emb_1]
        for block in self.model.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                embedded_timestep=embedded_timestep,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                extra_pos_emb=extra_pos_emb,
                attention_mask=attention_mask,
            )
        post_patch_num_frames = self.num_latent_frames // self.p_t
        post_patch_height = self.latent_height // self.p_h
        post_patch_width = self.latent_width // self.p_w
        hidden_states = self.model.norm_out(hidden_states, embedded_timestep, temb)
        hidden_states = self.model.proj_out(hidden_states)
        hidden_states = hidden_states.unflatten(2, (self.p_h, self.p_w, self.p_t, -1))
        hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
        hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        return (hidden_states,)


class RBLNCosmosTransformer3DModel(RBLNModel):
    """RBLN wrapper for the Cosmos Transformer model."""

    hf_library_name = "diffusers"
    auto_model_class = CosmosTransformer3DModel

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)
        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)

        hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
        patch_embed_in_channels = (
            self.config.in_channels + 1 if self.config.concat_padding_mask else self.config.in_channels
        )
        self.rope = CosmosRotaryPosEmbed(
            hidden_size=self.config.attention_head_dim,
            max_size=self.config.max_size,
            patch_size=self.config.patch_size,
            rope_scale=self.config.rope_scale,
        )
        self.rope.load_state_dict(artifacts["rope"])
        if artifacts["learnable_pos_embed"] is None:
            self.learnable_pos_embed = None
        else:
            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                hidden_size=hidden_size,
                max_size=self.config.max_size,
                patch_size=self.config.patch_size,
            )
            self.learnable_pos_embed.load_state_dict(artifacts["learnable_pos_embed"])
        self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, self.config.patch_size, bias=False)
        self.patch_embed.load_state_dict(artifacts["patch_embed"])
        self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
        self.time_embed.load_state_dict(artifacts["time_embed"])

    def compute_embedding(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
        condition_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ):
        batch_size, num_channels, num_frames, height, width = hidden_states.shape

        # 1. Concatenate padding mask if needed & prepare attention mask
        if condition_mask is not None:
            hidden_states = torch.cat([hidden_states, condition_mask], dim=1)

        if self.config.concat_padding_mask:
            padding_mask = transforms.functional.resize(
                padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
            )
            hidden_states = torch.cat(
                [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
            )

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

        # 2. Generate positional embeddings
        image_rotary_emb = self.rope(hidden_states, fps=fps)
        extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None

        # 3. Patchify input
        p_t, p_h, p_w = self.config.patch_size
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

        # 4. Timestep embeddings
        temb, embedded_timestep = self.time_embed(hidden_states, timestep)

        return (
            hidden_states,
            temb,
            embedded_timestep,
            image_rotary_emb[0],
            image_rotary_emb[1],
            extra_pos_emb,
            attention_mask,
        )

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        num_latent_frames = rbln_config.num_latent_frames
        latent_height = rbln_config.latent_height
        latent_width = rbln_config.latent_width
        return CosmosTransformer3DModelWrapper(
            model=model,
            num_latent_frames=num_latent_frames,
            latent_height=latent_height,
            latent_width=latent_width,
        ).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> RBLNCosmosTransformer3DModelConfig:
        rbln_config.transformer.num_latent_frames = (
            rbln_config.transformer.num_frames - 1
        ) // pipe.vae_scale_factor_temporal + 1
        rbln_config.transformer.latent_height = rbln_config.transformer.height // pipe.vae_scale_factor_spatial
        rbln_config.transformer.latent_width = rbln_config.transformer.width // pipe.vae_scale_factor_spatial
        rbln_config.transformer.max_seq_len = pipe.text_encoder.config.n_positions
        rbln_config.transformer.embedding_dim = pipe.text_encoder.encoder.embed_tokens.embedding_dim

        return rbln_config

    @classmethod
    def save_torch_artifacts(
        cls,
        model: "PreTrainedModel",
        save_dir_path: Path,
        subfolder: str,
        rbln_config: RBLNModelConfig,
    ):
        save_dict = {}
        save_dict["rope"] = model.rope.state_dict()
        if model.learnable_pos_embed is not None:
            save_dict["learnable_pos_embed"] = model.learnable_pos_embed.state_dict()
        save_dict["patch_embed"] = model.patch_embed.state_dict()
        save_dict["time_embed"] = model.time_embed.state_dict()
        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: "RBLNCosmosTransformer3DModelConfig",
    ) -> RBLNCosmosTransformer3DModelConfig:
        p_t, p_h, p_w = model_config.patch_size
        hidden_dim = (
            (rbln_config.num_latent_frames // p_t)
            * (rbln_config.latent_height // p_h)
            * (rbln_config.latent_width // p_w)
        )
        attention_head_dim = model_config.attention_head_dim
        hidden_size = model.config.num_attention_heads * model.config.attention_head_dim
        input_info = [
            (
                "hidden_states",
                [
                    rbln_config.batch_size,
                    hidden_dim,
                    hidden_size,
                ],
                "float32",
            ),
            (
                "encoder_hidden_states",
                [
                    rbln_config.batch_size,
                    rbln_config.max_seq_len,
                    rbln_config.embedding_dim,
                ],
                "float32",
            ),
            ("embedded_timestep", [rbln_config.batch_size, hidden_size], "float32"),
            ("temb", [1, hidden_size * 3], "float32"),
            ("image_rotary_emb_0", [hidden_dim, attention_head_dim], "float32"),
            ("image_rotary_emb_1", [hidden_dim, attention_head_dim], "float32"),
            ("extra_pos_emb", [rbln_config.batch_size, hidden_dim, hidden_size], "float32"),
        ]

        compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([compile_config])
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNModelConfig,
    ) -> List[rebel.Runtime]:
        if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
            cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])

        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
                activate_profiler=rbln_config.activate_profiler,
                timeout=120,
            )
            for compiled_model in compiled_models
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
        condition_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        (
            hidden_states,
            temb,
            embedded_timestep,
            image_rotary_emb_0,
            image_rotary_emb_1,
            extra_pos_emb,
            attention_mask,
        ) = self.compute_embedding(hidden_states, timestep, attention_mask, fps, condition_mask, padding_mask)

        hidden_states = self.model[0].forward(
            hidden_states,
            encoder_hidden_states,
            embedded_timestep,
            temb,
            image_rotary_emb_0,
            image_rotary_emb_1,
            extra_pos_emb,
        )

        if not return_dict:
            return (hidden_states,)
        else:
            return Transformer2DModelOutput(sample=hidden_states)
```
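The tail of `CosmosTransformer3DModelWrapper.forward` is the unpatchify step: it folds `[B, T'·H'·W', p_h·p_w·p_t·C]` tokens back into a `[B, C, T, H, W]` video tensor. The equivalence with a single einops `rearrange` can be sanity-checked on random data; the sizes below are illustrative, not taken from the diff.

```python
# Sanity-check sketch (illustrative shapes): the unflatten/permute/flatten
# chain above is the unpatchify step, equivalent to one einops rearrange.
import torch
from einops import rearrange

B, Tp, Hp, Wp = 1, 2, 3, 4   # post-patch frames/height/width (T', H', W')
pt, ph, pw, C = 1, 2, 2, 16  # patch sizes and output channels (assumed values)

x = torch.randn(B, Tp * Hp * Wp, ph * pw * pt * C)

y = x.unflatten(2, (ph, pw, pt, C)).unflatten(1, (Tp, Hp, Wp))
y = y.permute(0, 7, 1, 6, 2, 4, 3, 5).flatten(6, 7).flatten(4, 5).flatten(2, 3)

z = rearrange(
    x, "b (t h w) (ph pw pt c) -> b c (t pt) (h ph) (w pw)",
    t=Tp, h=Hp, w=Wp, ph=ph, pw=pw, pt=pt,
)

assert torch.equal(y, z)  # both produce [B, C, T, H, W]
```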
optimum/rbln/diffusers/models/transformers/transformer_sd3.py

```diff
@@ -59,6 +59,8 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
 
 
 class RBLNSD3Transformer2DModel(RBLNModel):
+    """RBLN wrapper for the Stable Diffusion 3 MMDiT Transformer model."""
+
     hf_library_name = "diffusers"
     auto_model_class = SD3Transformer2DModel
     _output_class = Transformer2DModelOutput
```
optimum/rbln/diffusers/models/unets/unet_2d_condition.py

```diff
@@ -140,6 +140,13 @@ class _UNet_Kandinsky(torch.nn.Module):
 
 
 class RBLNUNet2DConditionModel(RBLNModel):
+    """
+    Configuration class for RBLN UNet2DCondition models.
+
+    This class inherits from RBLNModelConfig and provides specific configuration options
+    for UNet2DCondition models used in diffusion-based image generation.
+    """
+
     hf_library_name = "diffusers"
     auto_model_class = UNet2DConditionModel
     _rbln_config_class = RBLNUNet2DConditionModelConfig
```
```diff
@@ -178,7 +185,10 @@ class RBLNUNet2DConditionModel(RBLNModel):
         rbln_config: RBLNUNet2DConditionModelConfig,
         image_size: Optional[Tuple[int, int]] = None,
     ) -> Tuple[int, int]:
-
+        if hasattr(pipe, "movq"):
+            scale_factor = 2 ** (len(pipe.movq.config.block_out_channels) - 1)
+        else:
+            scale_factor = pipe.vae_scale_factor
 
         if image_size is None:
             if "Img2Img" in pipe.__class__.__name__:
```
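The patched method now derives the latent scale factor from the pipeline itself: Kandinsky-style pipelines expose a `movq` VQ model whose downsampling is implied by its channel list, while other pipelines already carry a `vae_scale_factor`. The arithmetic in the `movq` branch is worth spelling out:

```python
# Worked example (values assumed for illustration): a MoVQ decoder with four
# entries in block_out_channels halves the resolution three times, so the
# latent grid is 8x smaller than the output image.
block_out_channels = [128, 256, 256, 512]
scale_factor = 2 ** (len(block_out_channels) - 1)
assert scale_factor == 8  # e.g. a 512x512 image maps to a 64x64 latent
```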
optimum/rbln/diffusers/pipelines/__init__.py

```diff
@@ -25,6 +25,11 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
         "RBLNStableDiffusionXLControlNetPipeline",
     ],
+    "cosmos": [
+        "RBLNCosmosTextToWorldPipeline",
+        "RBLNCosmosVideoToWorldPipeline",
+        "RBLNCosmosSafetyChecker",
+    ],
     "kandinsky2_2": [
         "RBLNKandinskyV22CombinedPipeline",
         "RBLNKandinskyV22Img2ImgCombinedPipeline",
```
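`_import_structure` is the diffusers/transformers lazy-import convention: the dict maps submodules to exported names, and the second hunk below mirrors the same names under `if TYPE_CHECKING:` for static analyzers. At runtime the module typically replaces itself with a lazy proxy so submodules load only on first attribute access; a sketch of the standard tail such files end with (the actual tail is not shown in this diff):

```python
# Sketch of the standard lazy-module tail (assumed, not shown in this diff).
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

if not TYPE_CHECKING:
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
    )
```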
```diff
@@ -58,6 +63,11 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
     )
+    from .cosmos import (
+        RBLNCosmosSafetyChecker,
+        RBLNCosmosTextToWorldPipeline,
+        RBLNCosmosVideoToWorldPipeline,
+    )
     from .kandinsky2_2 import (
         RBLNKandinskyV22CombinedPipeline,
         RBLNKandinskyV22Img2ImgCombinedPipeline,
```
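With these exports in place, the new pipelines follow the usual optimum-rbln entry point: `from_pretrained(..., export=True)` compiles the model for the NPU on first load. A hedged usage sketch; the model id and output handling are assumptions based on the upstream diffusers Cosmos pipelines, not on this diff:

```python
# Hedged usage sketch following the usual optimum-rbln from_pretrained flow.
from optimum.rbln import RBLNCosmosTextToWorldPipeline

pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # assumed upstream checkpoint
    export=True,  # compile for the RBLN NPU instead of loading a precompiled model
)
output = pipe(prompt="A robot arm assembling a circuit board")
frames = output.frames[0]  # generated video frames
```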
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py

```diff
@@ -14,7 +14,7 @@
 
 import os
 from pathlib import Path
-from typing import
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
```
```diff
@@ -24,9 +24,6 @@ from ....utils.logging import get_logger
 from ...models.controlnet import RBLNControlNetModel
 
 
-if TYPE_CHECKING:
-    pass
-
 logger = get_logger(__name__)
 
 
```
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py

```diff
@@ -49,6 +49,13 @@ logger = get_logger(__name__)
 
 
 class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion pipeline with ControlNet for guided text-to-image generation.
+
+    This pipeline compiles Stable Diffusion and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for generating images with precise structural control using conditioning inputs like edges, depth, or poses.
+    """
+
     original_class = StableDiffusionControlNetPipeline
     _rbln_config_class = RBLNStableDiffusionControlNetPipelineConfig
     _submodules = ["text_encoder", "unet", "vae", "controlnet"]
```
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

```diff
@@ -47,6 +47,13 @@ logger = logging.get_logger(__name__)
 
 
 class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion pipeline with ControlNet for guided image-to-image generation.
+
+    This pipeline compiles Stable Diffusion and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for transforming input images with precise structural control and conditioning guidance.
+    """
+
     original_class = StableDiffusionControlNetImg2ImgPipeline
     _submodules = ["text_encoder", "unet", "vae", "controlnet"]
     _rbln_config_class = RBLNStableDiffusionControlNetImg2ImgPipelineConfig
```
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

```diff
@@ -47,6 +47,13 @@ logger = logging.get_logger(__name__)
 
 
 class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion XL pipeline with ControlNet for high-resolution guided text-to-image generation.
+
+    This pipeline compiles Stable Diffusion XL and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for generating high-quality images with precise structural control and enhanced detail preservation.
+    """
+
     original_class = StableDiffusionXLControlNetPipeline
     _rbln_config_class = RBLNStableDiffusionXLControlNetPipelineConfig
     _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
```
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

```diff
@@ -47,6 +47,13 @@ logger = logging.get_logger(__name__)
 
 
 class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetImg2ImgPipeline):
+    """
+    RBLN-accelerated implementation of Stable Diffusion XL pipeline with ControlNet for high-resolution guided image-to-image generation.
+
+    This pipeline compiles Stable Diffusion XL and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
+    inference for transforming input images with precise structural control and enhanced quality preservation.
+    """
+
     original_class = StableDiffusionXLControlNetImg2ImgPipeline
     _rbln_config_class = RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig
     _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
```
optimum/rbln/diffusers/pipelines/cosmos/__init__.py (new file, +17 lines)

```python
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cosmos_guardrail import RBLNCosmosSafetyChecker
from .pipeline_cosmos_text2world import RBLNCosmosTextToWorldPipeline
from .pipeline_cosmos_video2world import RBLNCosmosVideoToWorldPipeline
```
optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py (new file, +102 lines)

```python
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple

from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
from ....transformers import RBLNSiglipVisionModelConfig


class RBLNVideoSafetyModelConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Video Content Safety Filter.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        input_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        self.input_size = input_size or 1152


class RBLNRetinaFaceFilterConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Retina Face Filter.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        self.image_size = image_size or (704, 1280)


class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Cosmos Safety Checker.
    """

    submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]

    def __init__(
        self,
        aegis: Optional[RBLNModelConfig] = None,
        video_safety_model: Optional[RBLNModelConfig] = None,
        face_blur_filter: Optional[RBLNModelConfig] = None,
        siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
        *,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        **kwargs: Dict[str, Any],
    ):
        super().__init__(**kwargs)
        if height is not None and width is not None:
            image_size = (height, width)

        self.aegis = self.init_submodule_config(RBLNModelConfig, aegis)
        self.siglip_encoder = self.init_submodule_config(
            RBLNSiglipVisionModelConfig,
            siglip_encoder,
            batch_size=batch_size,
            image_size=(384, 384),
        )

        self.video_safety_model = self.init_submodule_config(
            RBLNVideoSafetyModelConfig,
            video_safety_model,
            batch_size=batch_size,
            input_size=1152,
        )
        self.face_blur_filter = self.init_submodule_config(
            RBLNRetinaFaceFilterConfig,
            face_blur_filter,
            batch_size=batch_size,
            image_size=image_size,
        )


RBLNAutoConfig.register(RBLNVideoSafetyModelConfig)
RBLNAutoConfig.register(RBLNRetinaFaceFilterConfig)
RBLNAutoConfig.register(RBLNCosmosSafetyCheckerConfig)
```