optimum-rbln 0.8.1a4__py3-none-any.whl → 0.8.1a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. optimum/rbln/__init__.py +22 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +21 -1
  4. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  5. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +82 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_cosmos_transformer.py +68 -0
  8. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +110 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +1 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +41 -22
  12. optimum/rbln/diffusers/models/__init__.py +4 -0
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +209 -0
  15. optimum/rbln/diffusers/models/autoencoders/vae.py +49 -5
  16. optimum/rbln/diffusers/models/controlnet.py +1 -1
  17. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  18. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  19. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  20. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  21. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  22. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +395 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  25. optimum/rbln/transformers/__init__.py +2 -0
  26. optimum/rbln/transformers/models/__init__.py +8 -0
  27. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  28. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  29. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  30. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  31. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +2 -2
  32. optimum/rbln/transformers/models/t5/modeling_t5.py +10 -4
  33. optimum/rbln/utils/runtime_utils.py +3 -0
  34. {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/METADATA +4 -4
  35. {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/RECORD +37 -23
  36. {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/WHEEL +0 -0
  37. {optimum_rbln-0.8.1a4.dist-info → optimum_rbln-0.8.1a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py
@@ -0,0 +1,102 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional, Tuple
+
+ from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
+ from ....transformers import RBLNSiglipVisionModelConfig
+
+
+ class RBLNVideoSafetyModelConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLN Video Content Safety Filter.
+     """
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         input_size: Optional[int] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         self.input_size = input_size or 1152
+
+
+ class RBLNRetinaFaceFilterConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLN Retina Face Filter.
+     """
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         self.image_size = image_size or (704, 1280)
+
+
+ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLN Cosmos Safety Checker.
+     """
+
+     submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]
+
+     def __init__(
+         self,
+         aegis: Optional[RBLNModelConfig] = None,
+         video_safety_model: Optional[RBLNModelConfig] = None,
+         face_blur_filter: Optional[RBLNModelConfig] = None,
+         siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
+         *,
+         batch_size: Optional[int] = None,
+         image_size: Optional[Tuple[int, int]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         **kwargs: Dict[str, Any],
+     ):
+         super().__init__(**kwargs)
+         if height is not None and width is not None:
+             image_size = (height, width)
+
+         self.aegis = self.init_submodule_config(RBLNModelConfig, aegis)
+         self.siglip_encoder = self.init_submodule_config(
+             RBLNSiglipVisionModelConfig,
+             siglip_encoder,
+             batch_size=batch_size,
+             image_size=(384, 384),
+         )
+
+         self.video_safety_model = self.init_submodule_config(
+             RBLNVideoSafetyModelConfig,
+             video_safety_model,
+             batch_size=batch_size,
+             input_size=1152,
+         )
+         self.face_blur_filter = self.init_submodule_config(
+             RBLNRetinaFaceFilterConfig,
+             face_blur_filter,
+             batch_size=batch_size,
+             image_size=image_size,
+         )
+
+
+ RBLNAutoConfig.register(RBLNVideoSafetyModelConfig)
+ RBLNAutoConfig.register(RBLNRetinaFaceFilterConfig)
+ RBLNAutoConfig.register(RBLNCosmosSafetyCheckerConfig)
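
For orientation, a minimal sketch (not part of the diff) of how the new safety-checker config might be instantiated; it assumes `init_submodule_config` forwards the keyword overrides exactly as the hunk above suggests, and that submodule configs inherit device/npu fields from `RBLNModelConfig`:

from optimum.rbln.diffusers.pipelines.cosmos.configuration_cosmos_guardrail import (
    RBLNCosmosSafetyCheckerConfig,
)

# height/width are keyword-only and are folded into image_size=(704, 1280),
# which is then passed down to the face-blur (RetinaFace) submodule config.
config = RBLNCosmosSafetyCheckerConfig(batch_size=1, height=704, width=1280)
print(config.video_safety_model.input_size)  # 1152, fixed by the parent config
print(config.face_blur_filter.image_size)    # (704, 1280)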
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py
@@ -0,0 +1,395 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import pathlib
+ from functools import partial
+ from typing import Any, Dict, Optional, Tuple, Union
+ from unittest.mock import patch
+
+ import rebel
+ import torch
+ from diffusers.utils import is_cosmos_guardrail_available
+ from huggingface_hub import snapshot_download
+ from transformers import AutoTokenizer, SiglipProcessor
+
+ from .... import RBLNAutoModelForCausalLM, RBLNSiglipVisionModel
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .configuration_cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
+
+
+ if is_cosmos_guardrail_available():
+     from cosmos_guardrail import CosmosSafetyChecker
+     from cosmos_guardrail.cosmos_guardrail import (
+         COSMOS_GUARDRAIL_CHECKPOINT,
+         Aegis,
+         Blocklist,
+         GuardrailRunner,
+         ModelConfig,
+         RetinaFaceFilter,
+         SafetyClassifier,
+         SigLIPEncoder,
+         VideoContentSafetyFilter,
+         VideoSafetyModel,
+     )
+     from retinaface.data import cfg_re50
+
+     COSMOS_AVAILABLE = True
+ else:
+     COSMOS_AVAILABLE = False
+
+     class FailToImportCosmosGuardrail(torch.nn.Module): ...
+
+     class CosmosSafetyChecker(FailToImportCosmosGuardrail): ...
+
+     COSMOS_GUARDRAIL_CHECKPOINT = None
+
+     class Aegis(FailToImportCosmosGuardrail): ...
+
+     class Blocklist(FailToImportCosmosGuardrail): ...
+
+     class GuardrailRunner(FailToImportCosmosGuardrail): ...
+
+     class ModelConfig(FailToImportCosmosGuardrail): ...
+
+     class RetinaFaceFilter(FailToImportCosmosGuardrail): ...
+
+     class SafetyClassifier(FailToImportCosmosGuardrail): ...
+
+     class SigLIPEncoder(FailToImportCosmosGuardrail): ...
+
+     class VideoContentSafetyFilter(FailToImportCosmosGuardrail): ...
+
+     class VideoSafetyModel(FailToImportCosmosGuardrail): ...
+
+     cfg_re50 = None
+
+
+ def is_compiled_dir(dir: str) -> bool:
+     # walk directory and check if there is any *.rbln files in that dir.
+     if not os.path.exists(dir):
+         return False
+
+     for root, dirs, files in os.walk(dir):
+         for file in files:
+             if file.endswith(".rbln"):
+                 return True
+     return False
+
+
+ def get_image_features(
+     self,
+     pixel_values: torch.Tensor,
+     return_dict: bool = True,
+     output_attentions: bool = False,
+     output_hidden_states: bool = False,
+     interpolate_pos_encoding: bool = False,
+ ):
+     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+     return self(
+         pixel_values,
+         return_dict=return_dict,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         interpolate_pos_encoding=interpolate_pos_encoding,
+     )[1]
+
+
+ class RBLNSigLIPEncoder(SigLIPEncoder):
+     def __init__(
+         self,
+         model_name: str = "google/siglip-so400m-patch14-384",
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         if is_compiled_dir(checkpoint_id):
+             self.checkpoint_dir = (
+                 pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder"
+             ).as_posix()
+             self.processor = SiglipProcessor.from_pretrained(self.checkpoint_dir)
+
+             # We don't use RBLNSiglipModel, but we need to override get_image_features to return pooler_output
+             self.model = RBLNSiglipVisionModel.from_pretrained(
+                 self.checkpoint_dir,
+                 rbln_device=rbln_config.siglip_encoder.device,
+             )
+         else:
+             super().__init__(model_name, checkpoint_id)
+             model = self.model
+             del self.model
+             self.model = RBLNSiglipVisionModel.from_model(
+                 model,
+                 rbln_device=rbln_config.siglip_encoder.device,
+                 rbln_image_size=rbln_config.siglip_encoder.image_size,
+                 rbln_npu=rbln_config.siglip_encoder.npu,
+             )
+         self.rbln_config = rbln_config
+
+         # Override get_image_features to return pooler_output
+         self.model.get_image_features = lambda *args, **kwargs: get_image_features(self.model, *args, **kwargs)
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_dir = (pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder").as_posix()
+         self.model.save_pretrained(cache_dir)
+         self.processor.save_pretrained(cache_dir)
+
+
+ class RBLNRetinaFaceFilter(RetinaFaceFilter):
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         batch_size: int = 1,
+         confidence_threshold: float = 0.7,
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         if is_compiled_dir(checkpoint_id):
+             self.compiled_model = rebel.RBLNCompiledModel(
+                 pathlib.Path(checkpoint_id) / "face_blur_filter" / "retinaface.rbln"
+             )
+             self.cfg = cfg_re50
+             self.batch_size = batch_size
+             self.confidence_threshold = confidence_threshold
+             self.cfg["pretrain"] = False
+         else:
+             with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
+                 super().__init__(checkpoint_id)
+             net = self.net
+             del self.net
+             self.compiled_model = rebel.compile_from_torch(
+                 net,
+                 input_info=[
+                     (
+                         "frames",
+                         [
+                             self.batch_size,
+                             3,
+                             rbln_config.face_blur_filter.image_size[0],
+                             rbln_config.face_blur_filter.image_size[1],
+                         ],
+                         "float32",
+                     )
+                 ],
+                 npu=rbln_config.face_blur_filter.npu,
+             )
+
+         self.rbln_config = rbln_config
+         runtime = rebel.Runtime(self.compiled_model, tensor_type="pt", device=self.rbln_config.face_blur_filter.device)
+         self.net = RBLNPytorchRuntime(runtime)
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_path = pathlib.Path(checkpoint_id) / "face_blur_filter"
+         cache_path.mkdir(parents=True, exist_ok=True)
+         self.compiled_model.save(cache_path / "retinaface.rbln")
+
+
+ class RBLNVideoSafetyModel(VideoSafetyModel):
+     def __init__(
+         self,
+         config: ModelConfig,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         self.config = config
+         self.num_classes = config.num_classes
+         self.rbln_config = rbln_config
+
+         if is_compiled_dir(checkpoint_id):
+             self.compiled_model = rebel.RBLNCompiledModel(
+                 pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "safety_filter.rbln"
+             )
+         else:
+             # Load model from checkpoint
+             network = SafetyClassifier(
+                 input_size=self.rbln_config.video_safety_model.input_size, num_classes=self.num_classes
+             )
+             network.eval()
+
+             checkpoint_dir = snapshot_download(checkpoint_id)
+             checkpoint_dir = (pathlib.Path(checkpoint_dir) / "video_content_safety_filter").as_posix()
+
+             safety_filter_local_path = os.path.join(checkpoint_dir, "safety_filter.pt")
+             checkpoint = torch.load(safety_filter_local_path, weights_only=True)
+             network.load_state_dict({k.replace("network.", ""): v for k, v in checkpoint["model"].items()})
+
+             self.compiled_model = rebel.compile_from_torch(
+                 network,
+                 input_info=[
+                     (
+                         "data",
+                         [
+                             self.rbln_config.video_safety_model.batch_size,
+                             self.rbln_config.video_safety_model.input_size,
+                         ],
+                         "float32",
+                     )
+                 ],
+                 npu=self.rbln_config.video_safety_model.npu,
+             )
+
+         runtime = rebel.Runtime(
+             self.compiled_model,
+             tensor_type="pt",
+             device=self.rbln_config.video_safety_model.device,
+         )
+         self.network = RBLNPytorchRuntime(runtime)
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_path = pathlib.Path(checkpoint_id) / "video_content_safety_filter"
+         cache_path.mkdir(parents=True, exist_ok=True)
+         self.compiled_model.save(cache_path / "safety_filter.rbln")
+
+     def parameters(self):
+         yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
+
+
+ class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
+     ):
+         torch.nn.Module.__init__(self)
+         self.rbln_config = rbln_config
+         self.encoder = RBLNSigLIPEncoder(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
+
+         model_config = ModelConfig(input_size=1152, num_classes=7)
+         self.model = RBLNVideoSafetyModel(model_config, checkpoint_id=checkpoint_id, rbln_config=rbln_config)
+
+     def save_pretrained(self, checkpoint_id: str):
+         self.model.save_pretrained(checkpoint_id)
+         self.encoder.save_pretrained(checkpoint_id)
+
+
+ class RBLNAegis(Aegis):
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         base_model_id: str = "meta-llama/LlamaGuard-7b",
+         aegis_adapter: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ) -> None:
+         if is_compiled_dir(checkpoint_id):
+             torch.nn.Module.__init__(self)
+             cache_dir = pathlib.Path(checkpoint_id) / "aegis"
+             self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
+             self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_device=rbln_config.aegis.device)
+
+         else:
+             super().__init__(checkpoint_id, base_model_id, aegis_adapter)
+             model = self.model.merge_and_unload() # peft merge
+             del self.model
+
+             self.model = RBLNAutoModelForCausalLM.from_model(
+                 model,
+                 rbln_tensor_parallel_size=4,
+                 rbln_device=rbln_config.aegis.device,
+                 rbln_npu=rbln_config.aegis.npu,
+             )
+
+         self.rbln_config = rbln_config
+         self.dtype = torch.bfloat16
+         self.device = torch.device("cpu")
+
+     def save_pretrained(self, checkpoint_id: str):
+         cache_dir = pathlib.Path(checkpoint_id) / "aegis"
+         self.model.save_pretrained(cache_dir)
+         self.tokenizer.save_pretrained(cache_dir)
+
+
+ class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
+     """
+     RBLN-accelerated implementation of Cosmos Safety Checker.
+     """
+
+     def __init__(
+         self,
+         checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
+         aegis_model_id: str = "meta-llama/LlamaGuard-7b",
+         aegis_adapter_id: str = "nvidia/Aegis-AI-Content-Safety-LlamaGuard-Defensive-1.0",
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+     ) -> None:
+         torch.nn.Module.__init__(self)
+         if not COSMOS_AVAILABLE:
+             raise ImportError(
+                 "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+             )
+
+         if rbln_config is None:
+             rbln_config = RBLNCosmosSafetyCheckerConfig()
+
+         self.text_guardrail = GuardrailRunner(
+             safety_models=[
+                 Blocklist(COSMOS_GUARDRAIL_CHECKPOINT), # Changed since it cannot be saved
+                 RBLNAegis(checkpoint_id, aegis_model_id, aegis_adapter_id, rbln_config=rbln_config),
+             ]
+         )
+
+         with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
+             self.video_guardrail = GuardrailRunner(
+                 safety_models=[RBLNVideoContentSafetyFilter(checkpoint_id, rbln_config=rbln_config)],
+                 postprocessors=[RBLNRetinaFaceFilter(checkpoint_id, rbln_config=rbln_config)],
+             )
+
+         self.rbln_config = rbln_config
+
+     def save_pretrained(self, save_dir: str):
+         for text_safety_models in self.text_guardrail.safety_models:
+             if isinstance(text_safety_models, RBLNAegis):
+                 text_safety_models.save_pretrained(save_dir)
+
+         for video_safety_models in self.video_guardrail.safety_models:
+             if isinstance(video_safety_models, RBLNVideoContentSafetyFilter):
+                 video_safety_models.save_pretrained(save_dir)
+
+         for postprocessors in self.video_guardrail.postprocessors:
+             if isinstance(postprocessors, RBLNRetinaFaceFilter):
+                 postprocessors.save_pretrained(save_dir)
+
+         self.rbln_config._frozen = True # Ad-hoc to save config
+         self.rbln_config.save(save_dir)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         checkpoint_id: str,
+         rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+         subfolder: Optional[str] = None,
+         export: Optional[bool] = True,
+         **kwargs,
+     ):
+         rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
+
+         if len(kwargs) > 0:
+             raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
+
+         if subfolder is not None:
+             checkpoint_id = os.path.join(checkpoint_id, subfolder)
+
+         return cls(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
+
+     @classmethod
+     def prepare_rbln_config(
+         cls, rbln_config: Optional[Union[Dict[str, Any], RBLNCosmosSafetyCheckerConfig]] = None, **kwargs
+     ) -> Tuple[RBLNCosmosSafetyCheckerConfig, Dict[str, Any]]:
+         # Extract rbln-config from kwargs and convert it to RBLNCosmosSafetyCheckerConfig.
+         rbln_config, kwargs = RBLNCosmosSafetyCheckerConfig.initialize_from_kwargs(rbln_config, **kwargs)
+         return rbln_config, kwargs
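
A hedged usage sketch (not from the diff) of the compile-save-reload flow this module enables; the directory name is a placeholder, and the default constructor is assumed to download and compile the upstream cosmos_guardrail checkpoint on first use:

from optimum.rbln.diffusers.pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyChecker

# First run: no *.rbln files exist yet, so Aegis, the SigLIP encoder, the safety
# classifier and RetinaFace are compiled for the NPU and then serialized.
checker = RBLNCosmosSafetyChecker()
checker.save_pretrained("./cosmos_guardrail_rbln")

# Later runs: is_compiled_dir() finds the saved *.rbln artifacts, so the compiled
# runtimes are loaded directly and no recompilation happens.
checker = RBLNCosmosSafetyChecker.from_pretrained("./cosmos_guardrail_rbln")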
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
@@ -0,0 +1,98 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Optional
+
+ from diffusers import CosmosTextToWorldPipeline
+ from diffusers.schedulers import EDMEulerScheduler
+ from transformers import T5TokenizerFast
+
+ from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
+ from ....utils.logging import get_logger
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+ from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
+     """
+     RBLN-accelerated implementation of Cosmos Text to World pipeline for text-to-video generation.
+
+     This pipeline compiles Cosmos Text to World models to run efficiently on RBLN NPUs, enabling high-performance
+     inference for generating images with distinctive artistic style and enhanced visual quality.
+     """
+
+     original_class = CosmosTextToWorldPipeline
+     _submodules = ["text_encoder", "transformer", "vae"]
+     _optional_submodules = ["safety_checker"]
+
+     def __init__(
+         self,
+         text_encoder: RBLNT5EncoderModel,
+         tokenizer: T5TokenizerFast,
+         transformer: RBLNCosmosTransformer3DModel,
+         vae: RBLNAutoencoderKLCosmos,
+         scheduler: EDMEulerScheduler,
+         safety_checker: RBLNCosmosSafetyChecker = None,
+     ):
+         if safety_checker is None:
+             safety_checker = RBLNCosmosSafetyChecker()
+
+         super().__init__(
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             transformer=transformer,
+             vae=vae,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+         )
+
+     def handle_additional_kwargs(self, **kwargs):
+         if "num_frames" in kwargs and kwargs["num_frames"] != self.transformer.rbln_config.num_frames:
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'num_frames={self.transformer.rbln_config.num_frames}'. 'num_frames' set by the user will be ignored"
+             )
+             kwargs.pop("num_frames")
+         if (
+             "max_sequence_length" in kwargs
+             and kwargs["max_sequence_length"] != self.transformer.rbln_config.max_seq_len
+         ):
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'max_seq_len={self.transformer.rbln_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
+             )
+             kwargs.pop("max_sequence_length")
+         return kwargs
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str,
+         *,
+         export: bool = False,
+         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
+         rbln_config: Dict[str, Any] = {},
+         **kwargs: Dict[str, Any],
+     ):
+         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+         if safety_checker is None and export:
+             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
+
+         return super().from_pretrained(
+             model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
+         )
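
As a rough usage sketch (the checkpoint id and output handling are illustrative, not taken from this diff), the text-to-world pipeline follows the usual RBLNDiffusionMixin export pattern:

from optimum.rbln import RBLNCosmosTextToWorldPipeline  # assumed to be re-exported at the package top level

# export=True compiles text_encoder, transformer and vae, and builds the RBLN
# safety checker from rbln_config.safety_checker as shown in the hunk above.
pipe = RBLNCosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World",  # example checkpoint id
    export=True,
)
pipe.save_pretrained("./cosmos_text2world_rbln")

# Call-time num_frames / max_sequence_length values that differ from the compiled
# ones are dropped with a warning by handle_additional_kwargs.
video = pipe(prompt="A robot arm assembling a circuit board").frames[0]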
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
@@ -0,0 +1,98 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Optional
+
+ from diffusers import CosmosVideoToWorldPipeline
+ from diffusers.schedulers import EDMEulerScheduler
+ from transformers import T5TokenizerFast
+
+ from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
+ from ....utils.logging import get_logger
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
+ from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipeline):
+     """
+     RBLN-accelerated implementation of Cosmos Video to World pipeline for video-to-video generation.
+
+     This pipeline compiles Cosmos Video to World models to run efficiently on RBLN NPUs, enabling high-performance
+     inference for generating images with distinctive artistic style and enhanced visual quality.
+     """
+
+     original_class = CosmosVideoToWorldPipeline
+     _submodules = ["text_encoder", "transformer", "vae"]
+     _optional_components = ["safety_checker"]
+
+     def __init__(
+         self,
+         text_encoder: RBLNT5EncoderModel,
+         tokenizer: T5TokenizerFast,
+         transformer: RBLNCosmosTransformer3DModel,
+         vae: RBLNAutoencoderKLCosmos,
+         scheduler: EDMEulerScheduler,
+         safety_checker: RBLNCosmosSafetyChecker = None,
+     ):
+         if safety_checker is None:
+             safety_checker = RBLNCosmosSafetyChecker()
+
+         super().__init__(
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             transformer=transformer,
+             vae=vae,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+         )
+
+     def handle_additional_kwargs(self, **kwargs):
+         if "num_frames" in kwargs and kwargs["num_frames"] != self.transformer.rbln_config.num_frames:
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'num_frames={self.transformer.rbln_config.num_frames}'. 'num_frames' set by the user will be ignored"
+             )
+             kwargs.pop("num_frames")
+         if (
+             "max_sequence_length" in kwargs
+             and kwargs["max_sequence_length"] != self.transformer.rbln_config.max_seq_len
+         ):
+             logger.warning(
+                 f"The transformer in this pipeline is compiled with 'max_seq_len={self.transformer.rbln_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
+             )
+             kwargs.pop("max_sequence_length")
+         return kwargs
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: str,
+         *,
+         export: bool = False,
+         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
+         rbln_config: Dict[str, Any] = {},
+         **kwargs: Dict[str, Any],
+     ):
+         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
+         if safety_checker is None and export:
+             safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)
+
+         return super().from_pretrained(
+             model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
+         )
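
The video-to-world variant differs mainly in its conditioning input; a hedged sketch follows (the top-level re-export, the `image=` call signature and `.frames` output mirror the upstream CosmosVideoToWorldPipeline and are assumptions, not taken from this diff):

from diffusers.utils import load_image
from optimum.rbln import RBLNCosmosVideoToWorldPipeline  # assumed top-level re-export

# Reload a previously exported pipeline; export=False skips recompilation.
pipe = RBLNCosmosVideoToWorldPipeline.from_pretrained("./cosmos_video2world_rbln", export=False)

first_frame = load_image("conditioning_frame.png")  # placeholder conditioning image
video = pipe(image=first_frame, prompt="The camera pans across the workshop").frames[0]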