optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
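
The headline addition in 0.8.1 is the Cosmos pipeline family (entries 41-45 above), alongside new transformers integrations such as audio spectrogram transformer, ColPali, DistilBERT, ResNet, RoBERTa and ViT. As a rough orientation, the sketch below shows how the new text-to-world pipeline might be compiled with this release; it is a hedged example, not part of the diff, and both the top-level import path and the checkpoint id are assumptions.

    # Minimal sketch (assumed usage, not from the diff): compile the new Cosmos Text2World pipeline.
    from optimum.rbln import RBLNCosmosTextToWorldPipeline  # assumed top-level re-export

    pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
        "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # assumed checkpoint id; substitute your own
        export=True,  # compile text_encoder, transformer, vae and the safety checker for the NPU
    )
    pipe.save_pretrained("cosmos-text2world-rbln")  # keep the compiled artifacts for later runs
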
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py (new file)
@@ -0,0 +1,455 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import pathlib
+ from functools import partial
+ from typing import Any, Dict, Optional, Tuple, Union
+ from unittest.mock import patch
+
+ import rebel
+ import torch
+ from diffusers.utils import is_cosmos_guardrail_available
+ from huggingface_hub import snapshot_download
+ from transformers import AutoTokenizer, SiglipProcessor
+
+ from .... import RBLNAutoModelForCausalLM, RBLNSiglipVisionModel
+ from ....utils.runtime_utils import RBLNPytorchRuntime, UnavailableRuntime
+ from .configuration_cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
+
+
+ if is_cosmos_guardrail_available():
+     from cosmos_guardrail import CosmosSafetyChecker
+     from cosmos_guardrail.cosmos_guardrail import (
+         COSMOS_GUARDRAIL_CHECKPOINT,
+         Aegis,
+         Blocklist,
+         GuardrailRunner,
+         ModelConfig,
+         RetinaFaceFilter,
+         SafetyClassifier,
+         SigLIPEncoder,
+         VideoContentSafetyFilter,
+         VideoSafetyModel,
+     )
+     from retinaface.data import cfg_re50
+
+     COSMOS_AVAILABLE = True
+ else:
+     COSMOS_AVAILABLE = False
+
+     class FailToImportCosmosGuardrail(torch.nn.Module): ...
+
+     class CosmosSafetyChecker(FailToImportCosmosGuardrail): ...
+
+     COSMOS_GUARDRAIL_CHECKPOINT = None
+
+     class Aegis(FailToImportCosmosGuardrail): ...
+
+     class Blocklist(FailToImportCosmosGuardrail): ...
+
+     class GuardrailRunner(FailToImportCosmosGuardrail): ...
+
+     class ModelConfig(FailToImportCosmosGuardrail): ...
+
+     class RetinaFaceFilter(FailToImportCosmosGuardrail): ...
+
+     class SafetyClassifier(FailToImportCosmosGuardrail): ...
+
+     class SigLIPEncoder(FailToImportCosmosGuardrail): ...
+
+     class VideoContentSafetyFilter(FailToImportCosmosGuardrail): ...
+
+     class VideoSafetyModel(FailToImportCosmosGuardrail): ...
+
+     cfg_re50 = None
+
+
+ def is_compiled_dir(dir: str) -> bool:
+     # walk directory and check if there is any *.rbln files in that dir.
+     if not os.path.exists(dir):
+         return False
+
+     for root, dirs, files in os.walk(dir):
+         for file in files:
+             if file.endswith(".rbln"):
+                 return True
+     return False
+
+
+ def get_image_features(
+     self,
+     pixel_values: torch.Tensor,
+     return_dict: bool = True,
+     output_attentions: bool = False,
+     output_hidden_states: bool = False,
+     interpolate_pos_encoding: bool = False,
+ ):
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     return self(
+         pixel_values,
+         return_dict=return_dict,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         interpolate_pos_encoding=interpolate_pos_encoding,
+     )[1]
+
+
+ class RBLNSigLIPEncoder(SigLIPEncoder):
+     def __init__(
+         self,
+         model_name: str = "google/siglip-so400m-patch14-384",
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         if is_compiled_dir(checkpoint_id):
+             self.checkpoint_dir = (
+                 pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder"
+             ).as_posix()
+             self.processor = SiglipProcessor.from_pretrained(self.checkpoint_dir)
+
+             # We don't use RBLNSiglipModel, but we need to override get_image_features to return pooler_output
+             self.model = RBLNSiglipVisionModel.from_pretrained(
+                 self.checkpoint_dir,
+                 rbln_device=rbln_config.siglip_encoder.device,
+                 rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
+                 rbln_activate_profiler=rbln_config.siglip_encoder.activate_profiler,
+                 rbln_optimize_host_memory=rbln_config.siglip_encoder.optimize_host_memory,
+             )
+         else:
+             super().__init__(model_name, checkpoint_id)
+             model = self.model
+             del self.model
+             self.model = RBLNSiglipVisionModel.from_model(
+                 model,
+                 rbln_device=rbln_config.siglip_encoder.device,
+                 rbln_image_size=rbln_config.siglip_encoder.image_size,
+                 rbln_npu=rbln_config.siglip_encoder.npu,
+                 rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
+                 rbln_activate_profiler=rbln_config.siglip_encoder.activate_profiler,
+                 rbln_optimize_host_memory=rbln_config.siglip_encoder.optimize_host_memory,
+             )
+         self.rbln_config = rbln_config
+
+         # Override get_image_features to return pooler_output
+         self.model.get_image_features = lambda *args, **kwargs: get_image_features(self.model, *args, **kwargs)
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_dir = (pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder").as_posix()
+         self.model.save_pretrained(cache_dir)
+         self.processor.save_pretrained(cache_dir)
+
+
+ class RBLNRetinaFaceFilter(RetinaFaceFilter):
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         batch_size: int = 1,
+         confidence_threshold: float = 0.7,
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         if is_compiled_dir(checkpoint_id):
+             self.compiled_model = rebel.RBLNCompiledModel(
+                 pathlib.Path(checkpoint_id) / "face_blur_filter" / "retinaface.rbln"
+             )
+             self.cfg = cfg_re50
+             self.batch_size = batch_size
+             self.confidence_threshold = confidence_threshold
+             self.cfg["pretrain"] = False
+         else:
+             with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
+                 super().__init__(checkpoint_id)
+             net = self.net
+             del self.net
+             self.compiled_model = rebel.compile_from_torch(
+                 net,
+                 input_info=[
+                     (
+                         "frames",
+                         [
+                             self.batch_size,
+                             3,
+                             rbln_config.face_blur_filter.image_size[0],
+                             rbln_config.face_blur_filter.image_size[1],
+                         ],
+                         "float32",
+                     )
+                 ],
+                 npu=rbln_config.face_blur_filter.npu,
+             )
+
+         self.rbln_config = rbln_config
+
+         try:
+             runtime = (
+                 rebel.Runtime(
+                     self.compiled_model,
+                     tensor_type="pt",
+                     device=self.rbln_config.face_blur_filter.device,
+                     activate_profiler=rbln_config.face_blur_filter.activate_profiler,
+                 )
+                 if self.rbln_config.face_blur_filter.create_runtimes
+                 else UnavailableRuntime()
+             )
+         except rebel.core.exception.RBLNRuntimeError as e:
+             error_msg = (
+                 f"\nFailed to create RBLN runtime: {str(e)}\n\n"
+                 f"If you only need to compile the model without loading it to NPU, you can use:\n"
+                 f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
+                 f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
+                 f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+                 f"Make sure your NPU is properly installed and operational."
+             )
+             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
+
+         self.net = RBLNPytorchRuntime(runtime)
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_path = pathlib.Path(checkpoint_id) / "face_blur_filter"
+         cache_path.mkdir(parents=True, exist_ok=True)
+         self.compiled_model.save(cache_path / "retinaface.rbln")
+
+
+ class RBLNVideoSafetyModel(VideoSafetyModel):
+     def __init__(
+         self,
+         config: ModelConfig,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         self.config = config
+         self.num_classes = config.num_classes
+         self.rbln_config = rbln_config
+
+         if is_compiled_dir(checkpoint_id):
+             self.compiled_model = rebel.RBLNCompiledModel(
+                 pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "safety_filter.rbln"
+             )
+         else:
+             # Load model from checkpoint
+             network = SafetyClassifier(
+                 input_size=self.rbln_config.video_safety_model.input_size, num_classes=self.num_classes
+             )
+             network.eval()
+
+             checkpoint_dir = snapshot_download(checkpoint_id)
+             checkpoint_dir = (pathlib.Path(checkpoint_dir) / "video_content_safety_filter").as_posix()
+
+             safety_filter_local_path = os.path.join(checkpoint_dir, "safety_filter.pt")
+             checkpoint = torch.load(safety_filter_local_path, weights_only=True)
+             network.load_state_dict({k.replace("network.", ""): v for k, v in checkpoint["model"].items()})
+
+             self.compiled_model = rebel.compile_from_torch(
+                 network,
+                 input_info=[
+                     (
+                         "data",
+                         [
+                             self.rbln_config.video_safety_model.batch_size,
+                             self.rbln_config.video_safety_model.input_size,
+                         ],
+                         "float32",
+                     )
+                 ],
+                 npu=self.rbln_config.video_safety_model.npu,
+             )
+
+         try:
+             runtime = (
+                 rebel.Runtime(
+                     self.compiled_model,
+                     tensor_type="pt",
+                     device=self.rbln_config.video_safety_model.device,
+                     activate_profiler=rbln_config.video_safety_model.activate_profiler,
+                 )
+                 if self.rbln_config.video_safety_model.create_runtimes
+                 else UnavailableRuntime()
+             )
+         except rebel.core.exception.RBLNRuntimeError as e:
+             error_msg = (
+                 f"\nFailed to create RBLN runtime: {str(e)}\n\n"
+                 f"If you only need to compile the model without loading it to NPU, you can use:\n"
+                 f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
+                 f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
+                 f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+                 f"Make sure your NPU is properly installed and operational."
+             )
+             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
+
+         self.network = RBLNPytorchRuntime(runtime)
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_path = pathlib.Path(checkpoint_id) / "video_content_safety_filter"
+         cache_path.mkdir(parents=True, exist_ok=True)
+         self.compiled_model.save(cache_path / "safety_filter.rbln")
+
+     def parameters(self):
+         yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
+
+
+ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         self.rbln_config = rbln_config
+         self.encoder = RBLNSigLIPEncoder(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
+
+         model_config = ModelConfig(input_size=1152, num_classes=7)
+         self.model = RBLNVideoSafetyModel(model_config, checkpoint_id=checkpoint_id, rbln_config=rbln_config)
+
+     def save_pretrained(self, checkpoint_id: str):
+         self.model.save_pretrained(checkpoint_id)
+         self.encoder.save_pretrained(checkpoint_id)
+
+
+ class RBLNAegis(Aegis):
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         base_model_id: str = "meta-llama/LlamaGuard-7b",
+         aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ) -> None:
+         if is_compiled_dir(checkpoint_id):
+             torch.nn.Module.__init__(self)
+             cache_dir = pathlib.Path(checkpoint_id) / "aegis"
+             self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
+             self.model = RBLNAutoModelForCausalLM.from_pretrained(
+                 cache_dir,
+                 rbln_device=rbln_config.aegis.device,
+                 rbln_create_runtimes=rbln_config.aegis.create_runtimes,
+                 rbln_activate_profiler=rbln_config.aegis.activate_profiler,
+                 rbln_optimize_host_memory=rbln_config.aegis.optimize_host_memory,
+             )
+
+         else:
+             super().__init__(checkpoint_id, base_model_id, aegis_adapter)
+             model = self.model.merge_and_unload()  # peft merge
+             del self.model
+
+             self.model = RBLNAutoModelForCausalLM.from_model(
+                 model,
+                 rbln_tensor_parallel_size=4,
+                 rbln_device=rbln_config.aegis.device,
+                 rbln_create_runtimes=rbln_config.aegis.create_runtimes,
+                 rbln_npu=rbln_config.aegis.npu,
+                 rbln_activate_profiler=rbln_config.aegis.activate_profiler,
+                 rbln_optimize_host_memory=rbln_config.aegis.optimize_host_memory,
+             )
+
+         self.rbln_config = rbln_config
+         self.dtype = torch.bfloat16
+         self.device = torch.device("cpu")
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_dir = pathlib.Path(checkpoint_id) / "aegis"
+         self.model.save_pretrained(cache_dir)
+         self.tokenizer.save_pretrained(cache_dir)
+
+
+ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
+     """
+     RBLN-accelerated implementation of Cosmos Safety Checker.
+     """
+
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         aegis_model_id: str = "meta-llama/LlamaGuard-7b",
+         aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ) -> None:
+         torch.nn.Module.__init__(self)
+         if not COSMOS_AVAILABLE:
+             raise ImportError(
+                 "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+             )
+
+         if rbln_config is None:
+             rbln_config = RBLNCosmosSafetyCheckerConfig()
+         elif isinstance(rbln_config, dict):
+             rbln_config = RBLNCosmosSafetyCheckerConfig(**rbln_config)
+
+         self.text_guardrail = GuardrailRunner(
+             safety_models=[
+                 Blocklist(COSMOS_GUARDRAIL_CHECKPOINT),  # Changed since it cannot be saved
+                 RBLNAegis(
+                     checkpoint_id=checkpoint_id,
+                     base_model_id=aegis_model_id,
+                     aegis_adapter=aegis_adapter_id,
+                     rbln_config=rbln_config,
+                 ),
+             ]
+         )
+
+         self.video_guardrail = GuardrailRunner(
+             safety_models=[RBLNVideoContentSafetyFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)],
+             postprocessors=[RBLNRetinaFaceFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)],
+         )
+
+         self.rbln_config = rbln_config
+
+     def save_pretrained(self, save_dir: str):
+         for text_safety_models in self.text_guardrail.safety_models:
+             if isinstance(text_safety_models, RBLNAegis):
+                 text_safety_models.save_pretrained(save_dir)
+
+         for video_safety_models in self.video_guardrail.safety_models:
+             if isinstance(video_safety_models, RBLNVideoContentSafetyFilter):
+                 video_safety_models.save_pretrained(save_dir)
+
+         for postprocessors in self.video_guardrail.postprocessors:
+             if isinstance(postprocessors, RBLNRetinaFaceFilter):
+                 postprocessors.save_pretrained(save_dir)
+
+         self.rbln_config._frozen = True  # Ad-hoc to save config
+         self.rbln_config.save(save_dir)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         checkpoint_id: str,
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+         subfolder: Optional[str] = None,
+         export: Optional[bool] = True,
+         **kwargs,
+     ):
+         rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
+
+         if len(kwargs) > 0:
+             raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
+
+         if subfolder is not None:
+             checkpoint_id = os.path.join(checkpoint_id, subfolder)
+
+         return cls(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
+
+     @classmethod
+     def prepare_rbln_config(
+         cls, rbln_config: Optional[Union[Dict[str, Any], RBLNCosmosSafetyCheckerConfig]] = None, **kwargs
+     ) -> Tuple[RBLNCosmosSafetyCheckerConfig, Dict[str, Any]]:
+         # Extract rbln-config from kwargs and convert it to RBLNCosmosSafetyCheckerConfig.
+         rbln_config, kwargs = RBLNCosmosSafetyCheckerConfig.initialize_from_kwargs(rbln_config, **kwargs)
+         return rbln_config, kwargs
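
The wrapper above compiles each guardrail component (Aegis through RBLNAutoModelForCausalLM, the SigLIP encoder, the video safety classifier and the RetinaFace filter) and uses is_compiled_dir() to decide between compiling from the upstream checkpoint and reloading saved *.rbln artifacts. A minimal sketch of that round trip, assuming cosmos_guardrail is installed and enough NPUs are available for the tensor-parallel Aegis model:

    # Hedged sketch based on the class above: compile once, save, then reload without recompiling.
    from optimum.rbln.diffusers.pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyChecker

    checker = RBLNCosmosSafetyChecker()  # compiles from the default COSMOS_GUARDRAIL_CHECKPOINT
    checker.save_pretrained("cosmos-guardrail-rbln")

    # The saved directory now contains *.rbln files, so is_compiled_dir() returns True
    # and this reload skips compilation entirely.
    checker = RBLNCosmosSafetyChecker.from_pretrained("cosmos-guardrail-rbln")
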
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py (new file)
@@ -0,0 +1,98 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Optional
+
+ from diffusers import CosmosTextToWorldPipeline
+ from diffusers.schedulers import EDMEulerScheduler
+ from transformers import T5TokenizerFast
+
+ from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
+ from ....utils.logging import get_logger
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+ from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
+     """
+     RBLN-accelerated implementation of Cosmos Text to World pipeline for text-to-video generation.
+
+     This pipeline compiles Cosmos Text to World models to run efficiently on RBLN NPUs, enabling high-performance
+     inference for generating videos with distinctive artistic style and enhanced visual quality.
+     """
+
+     original_class = CosmosTextToWorldPipeline
+     _submodules = ["text_encoder", "transformer", "vae"]
+     _optional_submodules = ["safety_checker"]
+
+     def __init__(
+         self,
+         text_encoder: RBLNT5EncoderModel,
+         tokenizer: T5TokenizerFast,
+         transformer: RBLNCosmosTransformer3DModel,
+         vae: RBLNAutoencoderKLCosmos,
+         scheduler: EDMEulerScheduler,
+         safety_checker: RBLNCosmosSafetyChecker = None,
+     ):
+         if safety_checker is None:
+             safety_checker = RBLNCosmosSafetyChecker()
+
+         super().__init__(
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             transformer=transformer,
+             vae=vae,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+         )
+
+     def handle_additional_kwargs(self, **kwargs):
+         if "num_frames" in kwargs and kwargs["num_frames"] != self.transformer.rbln_config.num_frames:
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'num_frames={self.transformer.rbln_config.num_frames}'. 'num_frames' set by the user will be ignored"
+             )
+             kwargs.pop("num_frames")
+         if (
+             "max_sequence_length" in kwargs
+             and kwargs["max_sequence_length"] != self.transformer.rbln_config.max_seq_len
+         ):
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'max_seq_len={self.transformer.rbln_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
+             )
+             kwargs.pop("max_sequence_length")
+         return kwargs
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str,
+         *,
+         export: bool = False,
+         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
+         rbln_config: Dict[str, Any] = {},
+         **kwargs: Dict[str, Any],
+     ):
+         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+         if safety_checker is None and export:
+             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
+
+         return super().from_pretrained(
+             model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
+         )
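
Note that handle_additional_kwargs() above drops num_frames and max_sequence_length values that disagree with the compiled transformer configuration, warning rather than recompiling. A hedged sketch of what that means at call time, assuming a directory produced by an earlier export=True run and the upstream CosmosTextToWorldPipeline call signature:

    # Hedged sketch (not from the diff): compile-time settings win over call-time overrides.
    from optimum.rbln import RBLNCosmosTextToWorldPipeline  # assumed top-level re-export

    pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
        "cosmos-text2world-rbln",  # assumed path to a previously exported pipeline
        export=False,  # load the existing *.rbln artifacts as-is
    )

    # If 121 differs from the compiled num_frames, the pipeline logs
    # "'num_frames' set by the user will be ignored" and keeps the compiled value.
    output = pipe(prompt="A robot arm assembling a wooden chair", num_frames=121)
    video = output.frames[0]
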
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py (new file)
@@ -0,0 +1,98 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Optional
+
+ from diffusers import CosmosVideoToWorldPipeline
+ from diffusers.schedulers import EDMEulerScheduler
+ from transformers import T5TokenizerFast
+
+ from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
+ from ....utils.logging import get_logger
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+ from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipeline):
+     """
+     RBLN-accelerated implementation of Cosmos Video to World pipeline for video-to-video generation.
+
+     This pipeline compiles Cosmos Video to World models to run efficiently on RBLN NPUs, enabling high-performance
+     inference for generating videos with distinctive artistic style and enhanced visual quality.
+     """
+
+     original_class = CosmosVideoToWorldPipeline
+     _submodules = ["text_encoder", "transformer", "vae"]
+     _optional_submodules = ["safety_checker"]
+
+     def __init__(
+         self,
+         text_encoder: RBLNT5EncoderModel,
+         tokenizer: T5TokenizerFast,
+         transformer: RBLNCosmosTransformer3DModel,
+         vae: RBLNAutoencoderKLCosmos,
+         scheduler: EDMEulerScheduler,
+         safety_checker: RBLNCosmosSafetyChecker = None,
+     ):
+         if safety_checker is None:
+             safety_checker = RBLNCosmosSafetyChecker()
+
+         super().__init__(
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             transformer=transformer,
+             vae=vae,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+         )
+
+     def handle_additional_kwargs(self, **kwargs):
+         if "num_frames" in kwargs and kwargs["num_frames"] != self.transformer.rbln_config.num_frames:
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'num_frames={self.transformer.rbln_config.num_frames}'. 'num_frames' set by the user will be ignored"
+             )
+             kwargs.pop("num_frames")
+         if (
+             "max_sequence_length" in kwargs
+             and kwargs["max_sequence_length"] != self.transformer.rbln_config.max_seq_len
+         ):
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'max_seq_len={self.transformer.rbln_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
+             )
+             kwargs.pop("max_sequence_length")
+         return kwargs
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str,
+         *,
+         export: bool = False,
+         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
+         rbln_config: Dict[str, Any] = {},
+         **kwargs: Dict[str, Any],
+     ):
+         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+         if safety_checker is None and export:
+             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
+
+         return super().from_pretrained(
+             model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
+         )
optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -19,6 +19,13 @@ from ...modeling_diffusers import RBLNDiffusionMixin
  
  
  class RBLNKandinskyV22Pipeline(RBLNDiffusionMixin, KandinskyV22Pipeline):
+     """
+     RBLN-accelerated implementation of Kandinsky 2.2 pipeline for text-to-image generation.
+
+     This pipeline compiles Kandinsky 2.2 models to run efficiently on RBLN NPUs, enabling high-performance
+     inference for generating images with distinctive artistic style and enhanced visual quality.
+     """
+
      original_class = KandinskyV22Pipeline
      _rbln_config_class = RBLNKandinskyV22PipelineConfig
      _submodules = ["unet", "movq"]