optimum-rbln 0.8.1a5__py3-none-any.whl → 0.8.1a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. optimum/rbln/__init__.py +18 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +21 -1
  4. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  5. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
  8. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +0 -4
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +0 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +0 -4
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -4
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +0 -4
  15. optimum/rbln/diffusers/modeling_diffusers.py +57 -40
  16. optimum/rbln/diffusers/models/__init__.py +4 -0
  17. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  18. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +6 -1
  19. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  20. optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
  21. optimum/rbln/diffusers/models/autoencoders/vq_model.py +6 -1
  22. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  25. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +451 -0
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  30. optimum/rbln/modeling.py +38 -2
  31. optimum/rbln/modeling_base.py +18 -2
  32. optimum/rbln/transformers/modeling_generic.py +3 -3
  33. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  34. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  35. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  36. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  37. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +13 -1
  38. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -0
  39. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -2
  40. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -1
  41. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +22 -20
  42. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +6 -1
  43. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +8 -0
  44. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  45. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  46. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  47. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +8 -0
  48. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +16 -0
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -0
  50. optimum/rbln/transformers/models/resnet/configuration_resnet.py +6 -1
  51. optimum/rbln/transformers/models/resnet/modeling_resnet.py +5 -1
  52. optimum/rbln/transformers/models/roberta/configuration_roberta.py +12 -2
  53. optimum/rbln/transformers/models/roberta/modeling_roberta.py +16 -0
  54. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +6 -2
  55. optimum/rbln/transformers/models/siglip/configuration_siglip.py +7 -0
  56. optimum/rbln/transformers/models/siglip/modeling_siglip.py +7 -0
  57. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  58. optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
  59. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +7 -0
  60. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +6 -2
  61. optimum/rbln/transformers/models/vit/configuration_vit.py +6 -1
  62. optimum/rbln/transformers/models/vit/modeling_vit.py +7 -1
  63. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +7 -0
  64. optimum/rbln/transformers/models/whisper/configuration_whisper.py +7 -0
  65. optimum/rbln/transformers/models/whisper/modeling_whisper.py +6 -2
  66. optimum/rbln/utils/runtime_utils.py +49 -1
  67. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/METADATA +1 -1
  68. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/RECORD +70 -60
  69. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/WHEEL +0 -0
  70. {optimum_rbln-0.8.1a5.dist-info → optimum_rbln-0.8.1a7.dist-info}/licenses/LICENSE +0 -0
@@ -25,6 +25,11 @@ _import_structure = {
25
25
  "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
26
26
  "RBLNStableDiffusionXLControlNetPipeline",
27
27
  ],
28
+ "cosmos": [
29
+ "RBLNCosmosTextToWorldPipeline",
30
+ "RBLNCosmosVideoToWorldPipeline",
31
+ "RBLNCosmosSafetyChecker",
32
+ ],
28
33
  "kandinsky2_2": [
29
34
  "RBLNKandinskyV22CombinedPipeline",
30
35
  "RBLNKandinskyV22Img2ImgCombinedPipeline",
@@ -58,6 +63,11 @@ if TYPE_CHECKING:
58
63
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
59
64
  RBLNStableDiffusionXLControlNetPipeline,
60
65
  )
66
+ from .cosmos import (
67
+ RBLNCosmosSafetyChecker,
68
+ RBLNCosmosTextToWorldPipeline,
69
+ RBLNCosmosVideoToWorldPipeline,
70
+ )
61
71
  from .kandinsky2_2 import (
62
72
  RBLNKandinskyV22CombinedPipeline,
63
73
  RBLNKandinskyV22Img2ImgCombinedPipeline,
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
16
+ from .pipeline_cosmos_text2world import RBLNCosmosTextToWorldPipeline
17
+ from .pipeline_cosmos_video2world import RBLNCosmosVideoToWorldPipeline
@@ -0,0 +1,102 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
+ from ....transformers import RBLNSiglipVisionModelConfig
19
+
20
+
21
class RBLNVideoSafetyModelConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Video Content Safety Filter.

    Stores the batch size and the feature input size used when the video
    safety classifier is compiled.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        input_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size: Batch size of the classifier input. A missing or falsy
                value falls back to 1.
            input_size: Size of the classifier input vector. A missing or
                falsy value falls back to 1152.
            image_size: Accepted by the signature but not stored by this class.
            **kwargs: Passed through unchanged to ``RBLNModelConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size if batch_size else 1
        self.input_size = input_size if input_size else 1152
36
+
37
+
38
class RBLNRetinaFaceFilterConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Retina Face Filter.

    Stores the batch size and the (height, width) frame size used when the
    face detection network is compiled.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size: Number of frames per batch. A missing or falsy value
                falls back to 1.
            image_size: Frame size as a 2-tuple. A missing or falsy value
                falls back to ``(704, 1280)``.
            **kwargs: Passed through unchanged to ``RBLNModelConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size if batch_size else 1
        self.image_size = image_size if image_size else (704, 1280)
52
+
53
+
54
class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Cosmos Safety Checker.

    Aggregates one sub-configuration per guardrail component: ``aegis``,
    ``siglip_encoder``, ``video_safety_model`` and ``face_blur_filter``.
    """

    submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]

    def __init__(
        self,
        aegis: Optional[RBLNModelConfig] = None,
        video_safety_model: Optional[RBLNModelConfig] = None,
        face_blur_filter: Optional[RBLNModelConfig] = None,
        siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
        *,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        **kwargs: Dict[str, Any],
    ):
        """
        Args:
            aegis: Sub-configuration for the Aegis text guardrail model.
            video_safety_model: Sub-configuration for the video safety classifier.
            face_blur_filter: Sub-configuration for the RetinaFace blur filter.
            siglip_encoder: Sub-configuration for the SigLIP vision encoder.
            batch_size: Batch size handed to the SigLIP, video safety and face
                blur sub-configurations.
            image_size: Frame size handed to the face blur sub-configuration.
            height: Frame height; only used when ``width`` is given as well.
            width: Frame width; only used when ``height`` is given as well.
            **kwargs: Passed through unchanged to ``RBLNModelConfig.__init__``.
        """
        super().__init__(**kwargs)

        # A complete (height, width) pair overrides `image_size`; a lone
        # height or a lone width leaves `image_size` untouched.
        both_given = not (height is None or width is None)
        if both_given:
            image_size = (height, width)

        self.aegis = self.init_submodule_config(RBLNModelConfig, aegis)

        # The SigLIP encoder is always configured for 384x384 inputs here.
        siglip_options = {"batch_size": batch_size, "image_size": (384, 384)}
        self.siglip_encoder = self.init_submodule_config(
            RBLNSiglipVisionModelConfig, siglip_encoder, **siglip_options
        )

        # The classifier input size is fixed to 1152 here.
        classifier_options = {"batch_size": batch_size, "input_size": 1152}
        self.video_safety_model = self.init_submodule_config(
            RBLNVideoSafetyModelConfig, video_safety_model, **classifier_options
        )

        face_filter_options = {"batch_size": batch_size, "image_size": image_size}
        self.face_blur_filter = self.init_submodule_config(
            RBLNRetinaFaceFilterConfig, face_blur_filter, **face_filter_options
        )
98
+
99
+
100
# Register the guardrail configuration classes with RBLNAutoConfig, in the
# same order as they are defined above.
for _guardrail_config_cls in (
    RBLNVideoSafetyModelConfig,
    RBLNRetinaFaceFilterConfig,
    RBLNCosmosSafetyCheckerConfig,
):
    RBLNAutoConfig.register(_guardrail_config_cls)
@@ -0,0 +1,451 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import pathlib
17
+ from functools import partial
18
+ from typing import Any, Dict, Optional, Tuple, Union
19
+ from unittest.mock import patch
20
+
21
+ import rebel
22
+ import torch
23
+ from diffusers.utils import is_cosmos_guardrail_available
24
+ from huggingface_hub import snapshot_download
25
+ from transformers import AutoTokenizer, SiglipProcessor
26
+
27
+ from .... import RBLNAutoModelForCausalLM, RBLNSiglipVisionModel
28
+ from ....utils.runtime_utils import RBLNPytorchRuntime, UnavailableRuntime
29
+ from .configuration_cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
30
+
31
+
32
# `cosmos_guardrail` is an optional dependency. When it is installed the real
# classes are imported; otherwise same-named placeholders are defined so that
# the class statements further down in this module can still be evaluated.
if is_cosmos_guardrail_available():
    from cosmos_guardrail import CosmosSafetyChecker
    from cosmos_guardrail.cosmos_guardrail import (
        COSMOS_GUARDRAIL_CHECKPOINT,
        Aegis,
        Blocklist,
        GuardrailRunner,
        ModelConfig,
        RetinaFaceFilter,
        SafetyClassifier,
        SigLIPEncoder,
        VideoContentSafetyFilter,
        VideoSafetyModel,
    )
    from retinaface.data import cfg_re50

    COSMOS_AVAILABLE = True
else:
    COSMOS_AVAILABLE = False

    # Non-class names imported in the other branch are set to None.
    COSMOS_GUARDRAIL_CHECKPOINT = None
    cfg_re50 = None

    class FailToImportCosmosGuardrail(torch.nn.Module):
        pass

    class CosmosSafetyChecker(FailToImportCosmosGuardrail):
        pass

    class Aegis(FailToImportCosmosGuardrail):
        pass

    class Blocklist(FailToImportCosmosGuardrail):
        pass

    class GuardrailRunner(FailToImportCosmosGuardrail):
        pass

    class ModelConfig(FailToImportCosmosGuardrail):
        pass

    class RetinaFaceFilter(FailToImportCosmosGuardrail):
        pass

    class SafetyClassifier(FailToImportCosmosGuardrail):
        pass

    class SigLIPEncoder(FailToImportCosmosGuardrail):
        pass

    class VideoContentSafetyFilter(FailToImportCosmosGuardrail):
        pass

    class VideoSafetyModel(FailToImportCosmosGuardrail):
        pass
77
+
78
+
79
def is_compiled_dir(dir: str) -> bool:
    """
    Tell whether `dir` exists and contains at least one ``*.rbln`` file.

    The whole tree below `dir` is searched, and the search stops at the first
    matching file.

    Args:
        dir: Path to inspect. ``None`` and paths that do not exist are
            reported as not compiled.

    Returns:
        True if any file whose name ends with ``.rbln`` is found under `dir`,
        False otherwise.
    """
    # `None` is the placeholder checkpoint id used when cosmos_guardrail is not
    # installed; os.path.exists(None) would raise TypeError, so reject it here.
    if dir is None or not os.path.exists(dir):
        return False

    return any(
        file_name.endswith(".rbln")
        for _root, _subdirs, file_names in os.walk(dir)
        for file_name in file_names
    )
89
+
90
+
91
def get_image_features(
    self,
    pixel_values: torch.Tensor,
    return_dict: bool = True,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    interpolate_pos_encoding: bool = False,
):
    """
    Run the model on `pixel_values` and return element 1 of its output.

    This is a free function taking the model as `self`, so it can be attached
    to a model instance as a replacement for its ``get_image_features``.

    Args:
        self: Callable model exposing a ``config`` attribute.
        pixel_values: Input passed positionally to the model.
        return_dict: Forwarded to the model; ``None`` falls back to
            ``self.config.use_return_dict``.
        output_attentions: Forwarded to the model; ``None`` falls back to
            ``self.config.output_attentions``.
        output_hidden_states: Forwarded to the model; ``None`` falls back to
            ``self.config.output_hidden_states``.
        interpolate_pos_encoding: Forwarded to the model unchanged.

    Returns:
        The item at index 1 of the model output (expected to be the pooled
        output -- confirm against the wrapped model).
    """
    # Only an explicit None is replaced by the model configuration value.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    forward_options = {
        "return_dict": return_dict,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "interpolate_pos_encoding": interpolate_pos_encoding,
    }
    outputs = self(pixel_values, **forward_options)
    return outputs[1]
112
+
113
+
114
class RBLNSigLIPEncoder(SigLIPEncoder):
    """
    SigLIP encoder of the Cosmos video guardrail backed by ``RBLNSiglipVisionModel``.

    If `checkpoint_id` is a directory that already contains compiled ``*.rbln``
    files, the processor and the compiled model are loaded from
    ``<checkpoint_id>/video_content_safety_filter/siglip_encoder``. Otherwise
    the parent class is initialised first and its model is converted with
    ``RBLNSiglipVisionModel.from_model``.
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ):
        """
        Args:
            model_name: Model name forwarded to the parent constructor; only
                used when `checkpoint_id` is not a compiled directory.
            checkpoint_id: Compiled directory, or the checkpoint id forwarded
                to the parent constructor.
            rbln_config: Safety checker configuration; its ``siglip_encoder``
                sub-configuration is read. A default configuration is created
                when omitted.
        """
        torch.nn.Module.__init__(self)

        # The signature advertises None as the default, so make it usable
        # instead of failing on the first attribute access below.
        if rbln_config is None:
            rbln_config = RBLNCosmosSafetyCheckerConfig()
        siglip_config = rbln_config.siglip_encoder

        if is_compiled_dir(checkpoint_id):
            self.checkpoint_dir = (
                pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder"
            ).as_posix()
            self.processor = SiglipProcessor.from_pretrained(self.checkpoint_dir)

            # We don't use RBLNSiglipModel, but we need to override get_image_features to return pooler_output
            self.model = RBLNSiglipVisionModel.from_pretrained(
                self.checkpoint_dir,
                rbln_device=siglip_config.device,
                rbln_create_runtimes=siglip_config.create_runtimes,
                # Bug fix: this used to read the profiler flag from the
                # unrelated `aegis` sub-configuration.
                rbln_activate_profiler=siglip_config.activate_profiler,
            )
        else:
            super().__init__(model_name, checkpoint_id)
            # Swap the model created by the parent for its RBLN counterpart.
            model = self.model
            del self.model
            self.model = RBLNSiglipVisionModel.from_model(
                model,
                rbln_device=siglip_config.device,
                rbln_image_size=siglip_config.image_size,
                rbln_npu=siglip_config.npu,
                rbln_create_runtimes=siglip_config.create_runtimes,
                rbln_activate_profiler=siglip_config.activate_profiler,
            )
        self.rbln_config = rbln_config

        # Override get_image_features to return pooler_output
        self.model.get_image_features = lambda *args, **kwargs: get_image_features(self.model, *args, **kwargs)

    def save_pretrained(self, checkpoint_id: str):
        """
        Save the model and the processor under
        ``<checkpoint_id>/video_content_safety_filter/siglip_encoder``.

        Args:
            checkpoint_id: Root directory of the saved safety checker.
        """
        cache_dir = (pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder").as_posix()
        self.model.save_pretrained(cache_dir)
        self.processor.save_pretrained(cache_dir)
156
+
157
+
158
class RBLNRetinaFaceFilter(RetinaFaceFilter):
    """
    RetinaFace face blur filter whose network is executed through an RBLN runtime.

    If `checkpoint_id` is a directory that already contains compiled ``*.rbln``
    files, ``<checkpoint_id>/face_blur_filter/retinaface.rbln`` is loaded.
    Otherwise the parent class is initialised and its network is compiled with
    ``rebel.compile_from_torch``. In both cases ``self.net`` ends up wrapping
    the runtime (or an ``UnavailableRuntime`` when runtimes are disabled).
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        batch_size: int = 1,
        confidence_threshold: float = 0.7,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ):
        """
        Args:
            checkpoint_id: Compiled directory, or the checkpoint id forwarded
                to the parent constructor.
            batch_size: Number of frames per batch; also the batch dimension
                of the compiled input.
            confidence_threshold: Stored on the instance as
                ``confidence_threshold``.
            rbln_config: Safety checker configuration; its ``face_blur_filter``
                sub-configuration is read. A default configuration is created
                when omitted.

        Raises:
            rebel.core.exception.RBLNRuntimeError: If the runtime cannot be
                created; the original error is chained.
        """
        torch.nn.Module.__init__(self)

        # The signature advertises None as the default, so make it usable
        # instead of failing on the first attribute access below.
        if rbln_config is None:
            rbln_config = RBLNCosmosSafetyCheckerConfig()

        if is_compiled_dir(checkpoint_id):
            self.compiled_model = rebel.RBLNCompiledModel(
                pathlib.Path(checkpoint_id) / "face_blur_filter" / "retinaface.rbln"
            )
            self.cfg = cfg_re50
            self.batch_size = batch_size
            self.confidence_threshold = confidence_threshold
            self.cfg["pretrain"] = False
        else:
            # Force the parent's checkpoint loading onto the CPU with
            # weights_only=True for the duration of its constructor.
            with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
                super().__init__(checkpoint_id)

            # Only `checkpoint_id` is forwarded to the parent, so apply the
            # caller's settings explicitly; previously they were silently
            # ignored on this path, unlike on the compiled-directory path.
            self.batch_size = batch_size
            self.confidence_threshold = confidence_threshold

            # Take the torch network away from the module before compiling it.
            net = self.net
            del self.net
            self.compiled_model = rebel.compile_from_torch(
                net,
                input_info=[
                    (
                        "frames",
                        [
                            self.batch_size,
                            3,
                            rbln_config.face_blur_filter.image_size[0],
                            rbln_config.face_blur_filter.image_size[1],
                        ],
                        "float32",
                    )
                ],
                npu=rbln_config.face_blur_filter.npu,
            )

        self.rbln_config = rbln_config

        try:
            runtime = (
                rebel.Runtime(
                    self.compiled_model,
                    tensor_type="pt",
                    device=self.rbln_config.face_blur_filter.device,
                    activate_profiler=rbln_config.face_blur_filter.activate_profiler,
                )
                if self.rbln_config.face_blur_filter.create_runtimes
                else UnavailableRuntime()
            )
        except rebel.core.exception.RBLNRuntimeError as e:
            # Re-raise with guidance on compile-only usage and NPU diagnostics.
            error_msg = (
                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
                f"If you only need to compile the model without loading it to NPU, you can use:\n"
                f" from_pretrained(..., rbln_create_runtimes=False) or\n"
                f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
                f"Make sure your NPU is properly installed and operational."
            )
            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

        self.net = RBLNPytorchRuntime(runtime)

    def save_pretrained(self, checkpoint_id: str):
        """
        Save the compiled model to ``<checkpoint_id>/face_blur_filter/retinaface.rbln``.

        Args:
            checkpoint_id: Root directory of the saved safety checker; the
                sub-directory is created if needed.
        """
        cache_path = pathlib.Path(checkpoint_id) / "face_blur_filter"
        cache_path.mkdir(parents=True, exist_ok=True)
        self.compiled_model.save(cache_path / "retinaface.rbln")
227
+
228
+
229
class RBLNVideoSafetyModel(VideoSafetyModel):
    """
    Video safety classifier whose network is executed through an RBLN runtime.

    If `checkpoint_id` is a directory that already contains compiled ``*.rbln``
    files, ``<checkpoint_id>/video_content_safety_filter/safety_filter.rbln``
    is loaded. Otherwise the checkpoint snapshot is downloaded, the
    ``SafetyClassifier`` weights are loaded from ``safety_filter.pt`` and the
    network is compiled with ``rebel.compile_from_torch``.
    """

    def __init__(
        self,
        config: ModelConfig,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
    ):
        """
        Args:
            config: Model configuration; ``num_classes`` is read from it.
            checkpoint_id: Compiled directory, or the id passed to
                ``snapshot_download``.
            rbln_config: Safety checker configuration; its
                ``video_safety_model`` sub-configuration is read. A default
                configuration is created when omitted.

        Raises:
            rebel.core.exception.RBLNRuntimeError: If the runtime cannot be
                created; the original error is chained.
        """
        torch.nn.Module.__init__(self)

        # The signature advertises None as the default, so make it usable
        # instead of failing on the first attribute access below.
        if rbln_config is None:
            rbln_config = RBLNCosmosSafetyCheckerConfig()

        self.config = config
        self.num_classes = config.num_classes
        self.rbln_config = rbln_config

        if is_compiled_dir(checkpoint_id):
            self.compiled_model = rebel.RBLNCompiledModel(
                pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "safety_filter.rbln"
            )
        else:
            # Load model from checkpoint
            network = SafetyClassifier(
                input_size=self.rbln_config.video_safety_model.input_size, num_classes=self.num_classes
            )
            network.eval()

            checkpoint_dir = snapshot_download(checkpoint_id)
            checkpoint_dir = (pathlib.Path(checkpoint_dir) / "video_content_safety_filter").as_posix()

            safety_filter_local_path = os.path.join(checkpoint_dir, "safety_filter.pt")
            # Map tensors onto the CPU so that a checkpoint saved on another
            # device still loads on a host without that device; this matches
            # how the RetinaFace weights are loaded in this module.
            checkpoint = torch.load(
                safety_filter_local_path, weights_only=True, map_location=torch.device("cpu")
            )
            # Checkpoint keys carry a "network." prefix that the bare
            # classifier does not use.
            network.load_state_dict({k.replace("network.", ""): v for k, v in checkpoint["model"].items()})

            self.compiled_model = rebel.compile_from_torch(
                network,
                input_info=[
                    (
                        "data",
                        [
                            self.rbln_config.video_safety_model.batch_size,
                            self.rbln_config.video_safety_model.input_size,
                        ],
                        "float32",
                    )
                ],
                npu=self.rbln_config.video_safety_model.npu,
            )

        try:
            runtime = (
                rebel.Runtime(
                    self.compiled_model,
                    tensor_type="pt",
                    device=self.rbln_config.video_safety_model.device,
                    activate_profiler=rbln_config.video_safety_model.activate_profiler,
                )
                if self.rbln_config.video_safety_model.create_runtimes
                else UnavailableRuntime()
            )
        except rebel.core.exception.RBLNRuntimeError as e:
            # Re-raise with guidance on compile-only usage and NPU diagnostics.
            error_msg = (
                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
                f"If you only need to compile the model without loading it to NPU, you can use:\n"
                f" from_pretrained(..., rbln_create_runtimes=False) or\n"
                f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
                f"Make sure your NPU is properly installed and operational."
            )
            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

        self.network = RBLNPytorchRuntime(runtime)

    def save_pretrained(self, checkpoint_id: str):
        """
        Save the compiled model to
        ``<checkpoint_id>/video_content_safety_filter/safety_filter.rbln``.

        Args:
            checkpoint_id: Root directory of the saved safety checker; the
                sub-directory is created if needed.
        """
        cache_path = pathlib.Path(checkpoint_id) / "video_content_safety_filter"
        cache_path.mkdir(parents=True, exist_ok=True)
        self.compiled_model.save(cache_path / "safety_filter.rbln")

    def parameters(self):
        """
        Yield a single placeholder CPU float32 tensor.

        NOTE(review): presumably this exists so that callers inspecting
        ``next(model.parameters())`` for dtype/device keep working although
        the network runs in an RBLN runtime -- confirm against the parent class.
        """
        yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
305
+
306
+
307
class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
    """
    Video content safety filter built from an ``RBLNSigLIPEncoder`` and an
    ``RBLNVideoSafetyModel``.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
    ):
        """
        Args:
            checkpoint_id: Checkpoint id or compiled directory, handed to both
                components.
            rbln_config: Safety checker configuration, handed to both components.
        """
        torch.nn.Module.__init__(self)
        self.rbln_config = rbln_config

        shared_options = {"checkpoint_id": checkpoint_id, "rbln_config": rbln_config}
        self.encoder = RBLNSigLIPEncoder(**shared_options)

        # The classifier is always described as 1152 inputs / 7 classes here.
        self.model = RBLNVideoSafetyModel(ModelConfig(input_size=1152, num_classes=7), **shared_options)

    def save_pretrained(self, checkpoint_id: str):
        """
        Save the classifier first and the encoder second under `checkpoint_id`.

        Args:
            checkpoint_id: Root directory of the saved safety checker.
        """
        for component in (self.model, self.encoder):
            component.save_pretrained(checkpoint_id)
323
+
324
+
325
class RBLNAegis(Aegis):
    """
    Aegis text guardrail whose language model is an ``RBLNAutoModelForCausalLM``.

    If `checkpoint_id` is a directory that already contains compiled ``*.rbln``
    files, the tokenizer and the model are loaded from ``<checkpoint_id>/aegis``.
    Otherwise the parent class is initialised, its adapter is merged into the
    base model and the merged model is converted with ``from_model``.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        base_model_id: str = "meta-llama/LlamaGuard-7b",
        aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ) -> None:
        """
        Args:
            checkpoint_id: Compiled directory, or the checkpoint id forwarded
                to the parent constructor.
            base_model_id: Forwarded to the parent constructor on the export path.
            aegis_adapter: Forwarded to the parent constructor on the export path.
            rbln_config: Safety checker configuration; its ``aegis``
                sub-configuration is read.
        """
        if is_compiled_dir(checkpoint_id):
            torch.nn.Module.__init__(self)
            aegis_dir = pathlib.Path(checkpoint_id) / "aegis"
            self.tokenizer = AutoTokenizer.from_pretrained(aegis_dir)

            aegis_config = rbln_config.aegis
            load_options = {
                "rbln_device": aegis_config.device,
                "rbln_create_runtimes": aegis_config.create_runtimes,
                "rbln_activate_profiler": aegis_config.activate_profiler,
            }
            self.model = RBLNAutoModelForCausalLM.from_pretrained(aegis_dir, **load_options)

        else:
            super().__init__(checkpoint_id, base_model_id, aegis_adapter)
            merged_model = self.model.merge_and_unload()  # peft merge
            del self.model

            aegis_config = rbln_config.aegis
            # The tensor parallel size is fixed to 4 on this path.
            compile_options = {
                "rbln_tensor_parallel_size": 4,
                "rbln_device": aegis_config.device,
                "rbln_create_runtimes": aegis_config.create_runtimes,
                "rbln_npu": aegis_config.npu,
                "rbln_activate_profiler": aegis_config.activate_profiler,
            }
            self.model = RBLNAutoModelForCausalLM.from_model(merged_model, **compile_options)

        self.rbln_config = rbln_config
        self.dtype = torch.bfloat16
        self.device = torch.device("cpu")

    def save_pretrained(self, checkpoint_id: str):
        """
        Save the model first and the tokenizer second under ``<checkpoint_id>/aegis``.

        Args:
            checkpoint_id: Root directory of the saved safety checker.
        """
        aegis_dir = pathlib.Path(checkpoint_id) / "aegis"
        for component in (self.model, self.tokenizer):
            component.save_pretrained(aegis_dir)
366
+
367
+
368
class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
    """
    RBLN-accelerated implementation of Cosmos Safety Checker.

    Holds a text guardrail (``Blocklist`` + ``RBLNAegis``) and a video guardrail
    (``RBLNVideoContentSafetyFilter`` with ``RBLNRetinaFaceFilter`` as
    post-processor).
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        aegis_model_id: str = "meta-llama/LlamaGuard-7b",
        aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ) -> None:
        """
        Args:
            checkpoint_id: Checkpoint id or compiled directory handed to every
                RBLN component.
            aegis_model_id: Base model id handed to ``RBLNAegis``.
            aegis_adapter_id: Adapter id handed to ``RBLNAegis``.
            rbln_config: Configuration object, a dict of constructor arguments
                for ``RBLNCosmosSafetyCheckerConfig``, or None for the defaults.

        Raises:
            ImportError: If ``cosmos_guardrail`` is not installed.
        """
        torch.nn.Module.__init__(self)
        if not COSMOS_AVAILABLE:
            raise ImportError(
                "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
            )

        # Normalise the accepted configuration forms to a config object.
        if rbln_config is None:
            rbln_config = RBLNCosmosSafetyCheckerConfig()
        elif isinstance(rbln_config, dict):
            rbln_config = RBLNCosmosSafetyCheckerConfig(**rbln_config)

        # Changed since it cannot be saved
        blocklist = Blocklist(COSMOS_GUARDRAIL_CHECKPOINT)
        aegis = RBLNAegis(
            checkpoint_id=checkpoint_id,
            base_model_id=aegis_model_id,
            aegis_adapter=aegis_adapter_id,
            rbln_config=rbln_config,
        )
        self.text_guardrail = GuardrailRunner(safety_models=[blocklist, aegis])

        content_filter = RBLNVideoContentSafetyFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        face_filter = RBLNRetinaFaceFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        self.video_guardrail = GuardrailRunner(
            safety_models=[content_filter],
            postprocessors=[face_filter],
        )

        self.rbln_config = rbln_config

    def save_pretrained(self, save_dir: str):
        """
        Save every RBLN component and the configuration under `save_dir`.

        Components are visited in the order text safety models, video safety
        models, video post-processors; entries of any other type are skipped.

        Args:
            save_dir: Target directory.
        """
        saveable_groups = (
            (self.text_guardrail.safety_models, RBLNAegis),
            (self.video_guardrail.safety_models, RBLNVideoContentSafetyFilter),
            (self.video_guardrail.postprocessors, RBLNRetinaFaceFilter),
        )
        for components, rbln_type in saveable_groups:
            for component in components:
                if isinstance(component, rbln_type):
                    component.save_pretrained(save_dir)

        self.rbln_config._frozen = True  # Ad-hoc to save config
        self.rbln_config.save(save_dir)

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_id: str,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
        subfolder: Optional[str] = None,
        export: Optional[bool] = True,
        **kwargs,
    ):
        """
        Build a safety checker from `checkpoint_id`.

        Args:
            checkpoint_id: Checkpoint id or directory.
            rbln_config: Configuration handed to ``prepare_rbln_config``.
            subfolder: Optional path component joined onto `checkpoint_id`.
            export: Accepted by the signature; not read by this method.
            **kwargs: Handed to ``prepare_rbln_config``; anything it returns
                as left over is rejected.

        Raises:
            ValueError: If keyword arguments remain after the configuration
                has been prepared.
        """
        rbln_config, leftover = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)

        if leftover:
            raise ValueError(f"Unexpected arguments: {leftover.keys()}")

        if subfolder is not None:
            checkpoint_id = os.path.join(checkpoint_id, subfolder)

        return cls(checkpoint_id=checkpoint_id, rbln_config=rbln_config)

    @classmethod
    def prepare_rbln_config(
        cls, rbln_config: Optional[Union[Dict[str, Any], RBLNCosmosSafetyCheckerConfig]] = None, **kwargs
    ) -> Tuple[RBLNCosmosSafetyCheckerConfig, Dict[str, Any]]:
        """
        Delegate to ``RBLNCosmosSafetyCheckerConfig.initialize_from_kwargs``.

        Returns:
            The configuration and the keyword arguments returned by that call,
            as a 2-tuple.
        """
        # Extract rbln-config from kwargs and convert it to RBLNCosmosSafetyCheckerConfig.
        prepared_config, remaining = RBLNCosmosSafetyCheckerConfig.initialize_from_kwargs(rbln_config, **kwargs)
        return prepared_config, remaining
@@ -0,0 +1,98 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional
17
+
18
+ from diffusers import CosmosTextToWorldPipeline
19
+ from diffusers.schedulers import EDMEulerScheduler
20
+ from transformers import T5TokenizerFast
21
+
22
+ from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
23
+ from ....utils.logging import get_logger
24
+ from ...modeling_diffusers import RBLNDiffusionMixin
25
+ from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
26
+ from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
27
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
28
+
29
+
30
+ logger = get_logger(__name__)
31
+
32
+
33
class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
    """
    RBLN-accelerated implementation of Cosmos Text to World pipeline for text-to-video generation.

    This pipeline compiles Cosmos Text to World models to run efficiently on RBLN NPUs, enabling
    high-performance inference for generating videos from text prompts.
    """

    original_class = CosmosTextToWorldPipeline
    _submodules = ["text_encoder", "transformer", "vae"]
    _optional_submodules = ["safety_checker"]

    def __init__(
        self,
        text_encoder: RBLNT5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: RBLNCosmosTransformer3DModel,
        vae: RBLNAutoencoderKLCosmos,
        scheduler: EDMEulerScheduler,
        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
    ):
        """
        Args:
            text_encoder: Text encoder of the pipeline.
            tokenizer: Tokenizer of the pipeline.
            transformer: Transformer of the pipeline.
            vae: Autoencoder of the pipeline.
            scheduler: Scheduler of the pipeline.
            safety_checker: Safety checker; a default ``RBLNCosmosSafetyChecker``
                is created when omitted.
        """
        if safety_checker is None:
            safety_checker = RBLNCosmosSafetyChecker()

        super().__init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            safety_checker=safety_checker,
        )

    def handle_additional_kwargs(self, **kwargs):
        """
        Drop call arguments that conflict with the compiled transformer.

        ``num_frames`` and ``max_sequence_length`` are removed, with a warning,
        when they differ from the values in ``self.transformer.rbln_config``.

        Returns:
            The remaining keyword arguments.
        """
        if "num_frames" in kwargs and kwargs["num_frames"] != self.transformer.rbln_config.num_frames:
            logger.warning(
                f"The transformer in this pipeline is compiled with 'num_frames={self.transformer.rbln_config.num_frames}'. 'num_frames' set by the user will be ignored"
            )
            kwargs.pop("num_frames")
        if (
            "max_sequence_length" in kwargs
            and kwargs["max_sequence_length"] != self.transformer.rbln_config.max_seq_len
        ):
            logger.warning(
                f"The transformer in this pipeline is compiled with 'max_seq_len={self.transformer.rbln_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
            )
            kwargs.pop("max_sequence_length")
        return kwargs

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        *,
        export: bool = False,
        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        **kwargs: Dict[str, Any],
    ):
        """
        Load the pipeline, creating a safety checker first when exporting.

        Args:
            model_id: Model id or path forwarded to the parent ``from_pretrained``.
            export: When True and no `safety_checker` is given, a new
                ``RBLNCosmosSafetyChecker`` is built from
                ``rbln_config.safety_checker``.
            safety_checker: Safety checker to use instead of creating one.
            rbln_config: RBLN configuration; an empty dict is used when omitted.
            **kwargs: Handed to ``initialize_from_kwargs``; what it returns as
                left over is forwarded to the parent ``from_pretrained``.
        """
        # A `{}` default in the signature would be shared between calls, so a
        # fresh dict is created per call instead.
        if rbln_config is None:
            rbln_config = {}

        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
        if safety_checker is None and export:
            safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)

        return super().from_pretrained(
            model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
        )