optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union

 import rebel
 import torch
@@ -214,13 +214,41 @@ class RBLNAutoencoderKL(RBLNModel):
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]

-    def encode(
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
+    ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
+        """
+        Encode an input image into a latent representation.
+
+        Args:
+            x: The input image to encode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the encoder.
+
+        Returns:
+            The latent representation or AutoencoderKLOutput if return_dict=True
+        """
         posterior = self.encoder.encode(x)
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def decode(
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
+    ) -> Union[torch.FloatTensor, DecoderOutput]:
+        """
+        Decode a latent representation into an image.
+
+        Args:
+            z: The latent representation to decode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the decoder.
+
+        Returns:
+            The decoded image or DecoderOutput if return_dict=True
+        """
         dec = self.decoder.decode(z)
         if not return_dict:
             return (dec,)
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union

 import rebel
 import torch
@@ -205,13 +205,38 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]

-    def encode(
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
+    ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
+        """
+        Encode an input video into a latent representation.
+
+        Args:
+            x: The input video to encode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the encoder.
+
+        Returns:
+            The latent representation or AutoencoderKLOutput if return_dict=True
+        """
         posterior = self.encoder.encode(x)
         if not return_dict:
             return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)

-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[torch.FloatTensor, DecoderOutput]:
+        """
+        Decode a latent representation into a video.
+
+        Args:
+            z: The latent representation to decode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+
+        Returns:
+            The decoded video or DecoderOutput if return_dict=True
+        """
         decoded = self.decoder.decode(z)

         if not return_dict:
optimum/rbln/diffusers/models/autoencoders/vq_model.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, Any, List, Union

 import rebel
 import torch
@@ -170,13 +170,41 @@ class RBLNVQModel(RBLNModel):
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]

-    def encode(
+    def encode(
+        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+    ) -> Union[torch.FloatTensor, VQEncoderOutput]:
+        """
+        Encode an input image into a quantized latent representation.
+
+        Args:
+            x: The input image to encode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the encoder/quantizer.
+
+        Returns:
+            The quantized latent representation or a specific output object.
+        """
         posterior = self.encoder.encode(x)
         if not return_dict:
             return (posterior,)
         return VQEncoderOutput(latents=posterior)

-    def decode(
+    def decode(
+        self, h: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+    ) -> Union[torch.FloatTensor, DecoderOutput]:
+        """
+        Decode a quantized latent representation back into an image.
+
+        Args:
+            h: The quantized latent representation to decode.
+            return_dict:
+                Whether to return output as a dictionary. Defaults to True.
+            kwargs: Additional arguments to pass to the decoder.
+
+        Returns:
+            The decoded image or a DecoderOutput object.
+        """
         dec, commit_loss = self.decoder.decode(h, **kwargs)
         if not return_dict:
             return (dec, commit_loss)
optimum/rbln/diffusers/models/transformers/prior_transformer.py
@@ -59,7 +59,7 @@ class RBLNPriorTransformer(RBLNModel):
     """
     RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.

-    The
+    The PriorTransformer takes text and/or image embeddings from encoders (like CLIP) and
     maps them to a shared latent space that guides the diffusion process to generate the desired image.

     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py
@@ -94,7 +94,15 @@ class CosmosTransformer3DModelWrapper(torch.nn.Module):


 class RBLNCosmosTransformer3DModel(RBLNModel):
-    """
+    """
+    RBLN implementation of CosmosTransformer3DModel for diffusion models like Cosmos.
+
+    The CosmosTransformer3DModel takes text and/or image embeddings from encoders (like CLIP) and
+    maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """

     hf_library_name = "diffusers"
     auto_model_class = CosmosTransformer3DModel
optimum/rbln/diffusers/models/transformers/transformer_sd3.py
@@ -59,7 +59,15 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):


 class RBLNSD3Transformer2DModel(RBLNModel):
-    """
+    """
+    RBLN implementation of SD3Transformer2DModel for diffusion models like Stable Diffusion 3.
+
+    The SD3Transformer2DModel takes text and/or image embeddings from encoders (like CLIP) and
+    maps them to a shared latent space that guides the diffusion process to generate the desired image.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
+    """

     hf_library_name = "diffusers"
     auto_model_class = SD3Transformer2DModel
optimum/rbln/diffusers/models/unets/unet_2d_condition.py
@@ -141,10 +141,13 @@ class _UNet_Kandinsky(torch.nn.Module):

 class RBLNUNet2DConditionModel(RBLNModel):
     """
-
+    RBLN implementation of UNet2DConditionModel for diffusion models.

-    This
-
+    This model is used to accelerate UNet2DCondition models from diffusers library on RBLN NPUs.
+    It is a key component in diffusion-based image generation models like Stable Diffusion.
+
+    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+    the library implements for all its models.
     """

     hf_library_name = "diffusers"
optimum/rbln/diffusers/pipelines/__init__.py
@@ -18,6 +18,11 @@ from transformers.utils import _LazyModule


 _import_structure = {
+    "auto_pipeline": [
+        "RBLNAutoPipelineForImage2Image",
+        "RBLNAutoPipelineForInpainting",
+        "RBLNAutoPipelineForText2Image",
+    ],
     "controlnet": [
         "RBLNMultiControlNetModel",
         "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -56,6 +61,11 @@ _import_structure = {
     ],
 }
 if TYPE_CHECKING:
+    from .auto_pipeline import (
+        RBLNAutoPipelineForImage2Image,
+        RBLNAutoPipelineForInpainting,
+        RBLNAutoPipelineForText2Image,
+    )
     from .controlnet import (
         RBLNMultiControlNetModel,
         RBLNStableDiffusionControlNetImg2ImgPipeline,
@@ -63,11 +73,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
     )
-    from .cosmos import (
-        RBLNCosmosSafetyChecker,
-        RBLNCosmosTextToWorldPipeline,
-        RBLNCosmosVideoToWorldPipeline,
-    )
+    from .cosmos import RBLNCosmosSafetyChecker, RBLNCosmosTextToWorldPipeline, RBLNCosmosVideoToWorldPipeline
     from .kandinsky2_2 import (
         RBLNKandinskyV22CombinedPipeline,
         RBLNKandinskyV22Img2ImgCombinedPipeline,
optimum/rbln/diffusers/pipelines/auto_pipeline.py (new file)
@@ -0,0 +1,307 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import importlib
+from pathlib import Path
+from typing import Any, Dict, Type, Union
+
+from diffusers.models.controlnets import ControlNetUnionModel
+from diffusers.pipelines.auto_pipeline import (
+    AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
+    AUTO_INPAINT_PIPELINES_MAPPING,
+    AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+    AutoPipelineForImage2Image,
+    AutoPipelineForInpainting,
+    AutoPipelineForText2Image,
+    _get_task_class,
+)
+from huggingface_hub.utils import validate_hf_hub_args
+
+from optimum.rbln.configuration_utils import RBLNModelConfig
+from optimum.rbln.modeling_base import RBLNBaseModel
+from optimum.rbln.utils.model_utils import (
+    MODEL_MAPPING,
+    convert_hf_to_rbln_model_name,
+    convert_rbln_to_hf_model_name,
+    get_rbln_model_cls,
+)
+
+
+class RBLNAutoPipelineBase:
+    _model_mapping = None
+    _model_mapping_names = None
+
+    @classmethod
+    def get_rbln_cls(cls, pretrained_model_name_or_path: Union[str, Path], export: bool = None, **kwargs):
+        if isinstance(pretrained_model_name_or_path, Path):
+            pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()
+
+        if export is None:
+            export = not cls._is_compiled_pipeline(pretrained_model_name_or_path, **kwargs)
+
+        if export:
+            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
+            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
+        else:
+            rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
+            if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
+                raise ValueError(
+                    f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
+                    "Please use the `from_pretrained()` method of the appropriate class to load this model, "
+                    f"or directly use `{rbln_class_name}.from_pretrained()`."
+                )
+
+        try:
+            rbln_cls = get_rbln_model_cls(rbln_class_name)
+        except AttributeError as e:
+            raise AttributeError(
+                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
+                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
+            ) from e
+
+        return rbln_cls
+
+    @classmethod
+    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
+        """
+        Retrieve the RBLN pipeline class name recorded in a compiled pipeline's config.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier of the model.
+
+        Returns:
+            str: Name of the RBLN pipeline class.
+        """
+        model_index_config = cls.load_config(pretrained_model_name_or_path)
+
+        if "_class_name" not in model_index_config:
+            raise ValueError(
+                "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
+                "Please use the `from_pretrained()` method of the appropriate class to load this model."
+            )
+
+        return model_index_config["_class_name"]
+
+    @classmethod
+    def _is_compiled_pipeline(
+        cls,
+        pretrained_model_name_or_path: Union[str, Path],
+        cache_dir=None,
+        force_download=False,
+        proxies=None,
+        token=None,
+        local_files_only=False,
+        revision=None,
+        **kwargs,
+    ):
+        config: dict = cls.load_config(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            local_files_only=local_files_only,
+            revision=revision,
+        )
+        for value in config.values():
+            if isinstance(value, list) and len(value) > 0 and value[0] == "optimum.rbln":
+                return True
+        return False
+
+    @classmethod
+    def infer_hf_model_class(
+        cls,
+        pretrained_model_or_path: Union[str, Path],
+        cache_dir=None,
+        force_download=False,
+        proxies=None,
+        token=None,
+        local_files_only=False,
+        revision=None,
+        **kwargs,
+    ):
+        config = cls.load_config(
+            pretrained_model_or_path,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            local_files_only=local_files_only,
+            revision=revision,
+        )
+        pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)
+
+        pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)
+
+        return pipeline_cls
+
+    @classmethod
+    def get_pipeline_key_name(cls, config, **kwargs):
+        orig_class_name = config["_class_name"]
+        if "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
+
+        if "controlnet" in kwargs:
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
+            else:
+                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
+        if "enable_pag" in kwargs:
+            enable_pag = kwargs.pop("enable_pag")
+            if enable_pag:
+                orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
+
+        return orig_class_name
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        *,
+        export: bool = None,
+        rbln_config: Union[Dict[str, Any], RBLNModelConfig] = {},
+        **kwargs: Any,
+    ):
+        """
+        Load an RBLN-accelerated Diffusers pipeline from a pretrained checkpoint or a compiled RBLN artifact.
+
+        This method determines the concrete `RBLN*` model class that corresponds to the
+        underlying Diffusers pipeline architecture and dispatches to that class's
+        `from_pretrained()` implementation. If a compiled RBLN folder is detected at `model_id`
+        (or `export=False` is explicitly passed), it loads the compiled artifacts; otherwise it
+        compiles from the original Diffusers checkpoint.
+
+        Args:
+            model_id:
+                HF repo id or local path. For compiled models, this should point to a directory
+                (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
+            export:
+                Force compilation from a Diffusers checkpoint. When `None`, this is inferred by
+                checking whether compiled artifacts exist at `model_id`.
+            rbln_config:
+                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
+                instance of the specific model's config class (e.g., `RBLNFluxPipelineConfig`).
+            kwargs: Additional keyword arguments.
+                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
+                - Remaining arguments are forwarded to the Diffusers loader.
+
+        Returns:
+            RBLNBaseModel: An instantiated RBLN model wrapping the Diffusers pipeline, ready for
+            inference on RBLN NPUs.
+        """
+        rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
+        return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
+
+    @staticmethod
+    def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
+        """
+        Register a new RBLN model class.
+
+        Args:
+            rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
+            exist_ok (bool): Whether to allow registering an already registered model.
+        """
+        if not issubclass(rbln_cls, RBLNBaseModel):
+            raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")
+
+        native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
+        if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
+            if not exist_ok:
+                raise ValueError(f"Model for {rbln_cls.__name__} already registered.")
+
+        MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
+
+
+class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
+    """Text2Image AutoPipeline for RBLN NPUs."""
+
+    _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
+    _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()}
+
+
+class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
+    """Image2Image AutoPipeline for RBLN NPUs."""
+
+    _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
+    _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()}
+
+    @classmethod
+    def get_pipeline_key_name(cls, config, **kwargs):
+        orig_class_name = config["_class_name"]
+        # the `orig_class_name` can be:
+        # - `*Pipeline` (for regular text-to-image checkpoint)
+        # - `*ControlPipeline` (for Flux tools specific checkpoint)
+        # - `*Img2ImgPipeline` (for refiner checkpoint)
+        if "Img2Img" in orig_class_name:
+            to_replace = "Img2ImgPipeline"
+        elif "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
+
+        if "controlnet" in kwargs:
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+            else:
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+        if "enable_pag" in kwargs:
+            enable_pag = kwargs.pop("enable_pag")
+            if enable_pag:
+                orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+
+        if to_replace == "ControlPipeline":
+            orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
+
+        return orig_class_name
+
+
+class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
+    """Inpainting AutoPipeline for RBLN NPUs."""
+
+    _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
+    _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_INPAINT_PIPELINES_MAPPING.items()}
+
+    @classmethod
+    def get_pipeline_key_name(cls, config, **kwargs):
+        orig_class_name = config["_class_name"]
+
+        # The `orig_class_name` can be:
+        # - `*InpaintPipeline` (for inpaint-specific checkpoint)
+        # - `*ControlPipeline` (for Flux tools specific checkpoint)
+        # - `*Pipeline` (for regular text-to-image checkpoint)
+        if "Inpaint" in orig_class_name:
+            to_replace = "InpaintPipeline"
+        elif "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
+
+        if "controlnet" in kwargs:
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+            else:
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+        if "enable_pag" in kwargs:
+            enable_pag = kwargs.pop("enable_pag")
+            if enable_pag:
+                orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+        if to_replace == "ControlPipeline":
+            orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
+
+        return orig_class_name
optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
-from ....transformers import
+from ....transformers import RBLNSiglipVisionModelConfig


 class RBLNVideoSafetyModelConfig(RBLNModelConfig):
@@ -56,11 +56,11 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
     Configuration class for RBLN Cosmos Safety Checker.
     """

-    submodules = ["
+    submodules = ["llamaguard3", "video_safety_model", "face_blur_filter", "siglip_encoder"]

     def __init__(
         self,
-
+        llamaguard3: Optional[RBLNModelConfig] = None,
         video_safety_model: Optional[RBLNModelConfig] = None,
         face_blur_filter: Optional[RBLNModelConfig] = None,
         siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
@@ -69,37 +69,40 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
         image_size: Optional[Tuple[int, int]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-
+        max_seq_len: Optional[int] = None,
+        **kwargs: Any,
     ):
         super().__init__(**kwargs)
         if height is not None and width is not None:
             image_size = (height, width)

+        if max_seq_len is None:
+            max_seq_len = 512
+
         tensor_parallel_size = kwargs.get("tensor_parallel_size")

-        self.
-
-
+        self.llamaguard3 = self.initialize_submodule_config(
+            llamaguard3,
+            cls_name="RBLNLlamaForCausalLMConfig",
             batch_size=batch_size,
             tensor_parallel_size=tensor_parallel_size,
+            max_seq_len=max_seq_len,
         )
-
-        self.siglip_encoder = self.init_submodule_config(
-            RBLNSiglipVisionModelConfig,
+        self.siglip_encoder = self.initialize_submodule_config(
             siglip_encoder,
+            cls_name="RBLNSiglipVisionModelConfig",
             batch_size=batch_size,
             image_size=(384, 384),
         )
-
-        self.video_safety_model = self.init_submodule_config(
-            RBLNVideoSafetyModelConfig,
+        self.video_safety_model = self.initialize_submodule_config(
             video_safety_model,
+            cls_name="RBLNVideoSafetyModelConfig",
             batch_size=batch_size,
             input_size=1152,
         )
-        self.face_blur_filter = self.
-            RBLNRetinaFaceFilterConfig,
+        self.face_blur_filter = self.initialize_submodule_config(
             face_blur_filter,
+            cls_name="RBLNRetinaFaceFilterConfig",
             batch_size=batch_size,
             image_size=image_size,
         )