optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
16
+ from .pipeline_cosmos_text2world import RBLNCosmosTextToWorldPipeline
17
+ from .pipeline_cosmos_video2world import RBLNCosmosVideoToWorldPipeline
@@ -0,0 +1,113 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Optional, Tuple
16
+
17
+ from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
+ from ....transformers import RBLNSiglipVisionModelConfig
19
+
20
+
21
class RBLNVideoSafetyModelConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Video Content Safety Filter.

    Holds the static sizes used when compiling the safety classifier.

    Args:
        batch_size: Batch size of the compiled classifier. Any falsy value
            (``None`` or ``0``) falls back to ``1``.
        input_size: Size of the second (feature) dimension of the classifier
            input. Any falsy value falls back to ``1152``.
        image_size: Accepted for signature compatibility; it is not stored on
            this config.
        **kwargs: Forwarded unchanged to ``RBLNModelConfig``.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        input_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Falsy inputs (None, 0) select the defaults, mirroring ``value or default``.
        self.batch_size = batch_size if batch_size else 1
        self.input_size = input_size if input_size else 1152
36
+
37
+
38
class RBLNRetinaFaceFilterConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Retina Face Filter.

    Args:
        batch_size: Batch size stored on the config. Any falsy value
            (``None`` or ``0``) falls back to ``1``.
        image_size: ``(height, width)`` of the frames the face filter is
            compiled for. Any falsy value falls back to ``(704, 1280)``.
        **kwargs: Forwarded unchanged to ``RBLNModelConfig``.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Falsy inputs select the defaults, mirroring ``value or default``.
        self.batch_size = batch_size if batch_size else 1
        self.image_size = image_size if image_size else (704, 1280)
52
+
53
+
54
class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Cosmos Safety Checker.

    Aggregates the sub-configurations of the four guardrail components listed
    in ``submodules``. Each one is built through
    ``initialize_submodule_config``.

    Args:
        llamaguard3: Config (or ``None``) for the LlamaGuard3 text checker,
            built as ``RBLNLlamaForCausalLMConfig``.
        video_safety_model: Config (or ``None``) for the video safety
            classifier, built as ``RBLNVideoSafetyModelConfig`` with
            ``input_size=1152``.
        face_blur_filter: Config (or ``None``) for the face blur filter, built
            as ``RBLNRetinaFaceFilterConfig``.
        siglip_encoder: Config (or ``None``) for the SigLIP vision encoder,
            built as ``RBLNSiglipVisionModelConfig`` with
            ``image_size=(384, 384)``.
        batch_size: Forwarded to every submodule config.
        image_size: ``(height, width)`` forwarded to the face blur filter.
            Replaced by ``(height, width)`` when both of those are given.
        height: Frame height. Only used together with ``width``.
        width: Frame width. Only used together with ``height``.
        max_seq_len: Max sequence length for LlamaGuard3; defaults to 512.
        **kwargs: Forwarded to ``RBLNModelConfig``. A
            ``tensor_parallel_size`` entry is also forwarded to LlamaGuard3.
    """

    submodules = ["llamaguard3", "video_safety_model", "face_blur_filter", "siglip_encoder"]

    def __init__(
        self,
        llamaguard3: Optional[RBLNModelConfig] = None,
        video_safety_model: Optional[RBLNModelConfig] = None,
        face_blur_filter: Optional[RBLNModelConfig] = None,
        siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
        *,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        max_seq_len: Optional[int] = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)

        # A complete (height, width) pair wins over ``image_size``; a lone
        # height or width is ignored.
        face_image_size = (height, width) if (height is not None and width is not None) else image_size
        llama_max_seq_len = 512 if max_seq_len is None else max_seq_len

        # ``tensor_parallel_size`` stays in ``kwargs`` (already given to the
        # base class) and is additionally passed on to LlamaGuard3.
        tp_size = kwargs.get("tensor_parallel_size")

        # Fixed sizes for the SigLIP encoder and the classifier head.
        siglip_image_size = (384, 384)
        safety_input_size = 1152

        self.llamaguard3 = self.initialize_submodule_config(
            llamaguard3,
            cls_name="RBLNLlamaForCausalLMConfig",
            batch_size=batch_size,
            tensor_parallel_size=tp_size,
            max_seq_len=llama_max_seq_len,
        )
        self.siglip_encoder = self.initialize_submodule_config(
            siglip_encoder,
            cls_name="RBLNSiglipVisionModelConfig",
            batch_size=batch_size,
            image_size=siglip_image_size,
        )
        self.video_safety_model = self.initialize_submodule_config(
            video_safety_model,
            cls_name="RBLNVideoSafetyModelConfig",
            batch_size=batch_size,
            input_size=safety_input_size,
        )
        self.face_blur_filter = self.initialize_submodule_config(
            face_blur_filter,
            cls_name="RBLNRetinaFaceFilterConfig",
            batch_size=batch_size,
            image_size=face_image_size,
        )
109
+
110
+
111
# Register each config class with RBLNAutoConfig, in the original order.
# Presumably this lets them be resolved by the ``cls_name`` strings used in
# ``initialize_submodule_config`` — TODO confirm against RBLNAutoConfig.
for _config_cls in (
    RBLNVideoSafetyModelConfig,
    RBLNRetinaFaceFilterConfig,
    RBLNCosmosSafetyCheckerConfig,
):
    RBLNAutoConfig.register(_config_cls)
del _config_cls
@@ -0,0 +1,425 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import pathlib
17
+ from functools import partial
18
+ from typing import Any, Dict, Optional, Tuple, Union
19
+ from unittest.mock import patch
20
+
21
+ import rebel
22
+ import torch
23
+ from diffusers.utils import is_cosmos_guardrail_available
24
+ from huggingface_hub import snapshot_download
25
+ from transformers import AutoTokenizer, SiglipProcessor
26
+
27
+ from .... import RBLNAutoModelForCausalLM, RBLNSiglipVisionModel
28
+ from ....utils.runtime_utils import RBLNPytorchRuntime, UnavailableRuntime
29
+ from .configuration_cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
30
+
31
+
32
# Optional dependency: the Cosmos guardrail package. When it is missing, every
# name this module uses from it is replaced by an inert placeholder so that the
# module (and the RBLN subclasses defined below) can still be imported.
COSMOS_AVAILABLE = is_cosmos_guardrail_available()

if COSMOS_AVAILABLE:
    from cosmos_guardrail import CosmosSafetyChecker
    from cosmos_guardrail.cosmos_guardrail import (
        COSMOS_GUARDRAIL_CHECKPOINT,
        Blocklist,
        GuardrailRunner,
        LlamaGuard3,
        ModelConfig,
        RetinaFaceFilter,
        SafetyClassifier,
        SigLIPEncoder,
        VideoContentSafetyFilter,
        VideoSafetyModel,
    )
    from retinaface.data import cfg_re50
else:
    # Placeholder values for the non-class names.
    COSMOS_GUARDRAIL_CHECKPOINT = None
    cfg_re50 = None

    class FailToImportCosmosGuardrail(torch.nn.Module):
        """Common base of the placeholders used when ``cosmos_guardrail`` is unavailable."""

    class CosmosSafetyChecker(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.CosmosSafetyChecker``."""

    class LlamaGuard3(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.LlamaGuard3``."""

    class Blocklist(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.Blocklist``."""

    class GuardrailRunner(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.GuardrailRunner``."""

    class ModelConfig(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.ModelConfig``."""

    class RetinaFaceFilter(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.RetinaFaceFilter``."""

    class SafetyClassifier(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.SafetyClassifier``."""

    class SigLIPEncoder(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.SigLIPEncoder``."""

    class VideoContentSafetyFilter(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.VideoContentSafetyFilter``."""

    class VideoSafetyModel(FailToImportCosmosGuardrail):
        """Placeholder for ``cosmos_guardrail.cosmos_guardrail.VideoSafetyModel``."""
77
+
78
+
79
def is_compiled_dir(dir: Optional[Union[str, os.PathLike]]) -> bool:
    """Return True if ``dir`` is an existing directory containing any ``*.rbln`` file.

    The directory tree is searched recursively, and the search stops at the
    first match.

    Args:
        dir: Local directory path. ``None`` or an empty string is accepted and
            yields ``False``. This matters because callers default
            ``checkpoint_id`` to ``COSMOS_GUARDRAIL_CHECKPOINT``, which is
            ``None`` when ``cosmos_guardrail`` is not installed, and
            ``os.path.exists(None)`` raises ``TypeError``.

    Returns:
        ``True`` if at least one file under ``dir`` ends with ``.rbln``.
        ``False`` otherwise, including when ``dir`` is missing or is not a
        directory.
    """
    if not dir or not os.path.isdir(dir):
        return False

    return any(
        file_name.endswith(".rbln")
        for _root, _dirs, files in os.walk(dir)
        for file_name in files
    )
89
+
90
+
91
def get_image_features(
    self,
    pixel_values: torch.Tensor,
    return_dict: bool = True,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
    interpolate_pos_encoding: bool = False,
):
    """Call the vision model on ``pixel_values`` and return element ``[1]`` of its output.

    This is a module-level function; ``self`` is the model instance it is
    bound to by the caller.

    Flags explicitly passed as ``None`` are resolved from ``self.config``:
    ``output_attentions``, ``output_hidden_states``, and ``use_return_dict``
    (the last one for ``return_dict``). ``interpolate_pos_encoding`` is
    forwarded as given.

    Returns:
        ``outputs[1]`` of the model call. The surrounding code relies on this
        being the pooled output (presumably ``(last_hidden_state,
        pooler_output, ...)`` ordering — confirm against the model's output
        class).
    """
    # Resolve in the same order as before, touching ``self.config`` only when needed.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    outputs = self(
        pixel_values,
        return_dict=return_dict,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        interpolate_pos_encoding=interpolate_pos_encoding,
    )
    return outputs[1]
112
+
113
+
114
class RBLNSigLIPEncoder(SigLIPEncoder):
    """SigLIP image encoder of the Cosmos video content safety filter, backed by ``RBLNSiglipVisionModel``.

    If ``checkpoint_id`` is a local directory that already contains ``*.rbln``
    files, the processor and the compiled vision model are loaded from
    ``<checkpoint_id>/video_content_safety_filter/siglip_encoder``.

    Otherwise the upstream ``SigLIPEncoder`` is constructed and its ``model``
    is converted with ``RBLNSiglipVisionModel.from_model``.

    In both cases ``self.model.get_image_features`` is replaced by the
    module-level ``get_image_features`` helper bound to ``self.model``.
    """

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ):
        """
        Args:
            model_name: Hugging Face model id passed to the upstream
                ``SigLIPEncoder`` when compiling from scratch.
            checkpoint_id: Compiled local directory, or the guardrail
                checkpoint id used by the upstream constructor.
            rbln_config: Aggregate safety-checker config. Its
                ``siglip_encoder`` entry is forwarded to the RBLN model. When
                ``None``, ``None`` is forwarded instead of failing with an
                ``AttributeError`` on ``None.siglip_encoder``.
        """
        torch.nn.Module.__init__(self)
        siglip_config = rbln_config.siglip_encoder if rbln_config is not None else None

        if is_compiled_dir(checkpoint_id):
            self.checkpoint_dir = (
                pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder"
            ).as_posix()
            self.processor = SiglipProcessor.from_pretrained(self.checkpoint_dir)

            # We don't use RBLNSiglipModel, but we need to override get_image_features to return pooler_output
            self.model = RBLNSiglipVisionModel.from_pretrained(self.checkpoint_dir, rbln_config=siglip_config)
        else:
            super().__init__(model_name, checkpoint_id)
            # Swap the upstream torch model for its RBLN-compiled counterpart.
            model = self.model
            del self.model
            self.model = RBLNSiglipVisionModel.from_model(model, rbln_config=siglip_config)

        # Set on both branches so attribute access never depends on how the encoder was loaded.
        self.rbln_config = rbln_config

        # Override get_image_features to return pooler_output
        self.model.get_image_features = lambda *args, **kwargs: get_image_features(self.model, *args, **kwargs)

    def save_pretrained(self, checkpoint_id: str):
        """Save the RBLN vision model and processor under ``<checkpoint_id>/video_content_safety_filter/siglip_encoder``."""
        cache_dir = (pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "siglip_encoder").as_posix()
        self.model.save_pretrained(cache_dir)
        self.processor.save_pretrained(cache_dir)
146
+
147
+
148
class RBLNRetinaFaceFilter(RetinaFaceFilter):
    """RetinaFace face blur filter whose network runs through an RBLN runtime.

    If ``checkpoint_id`` is a local directory containing ``*.rbln`` files, the
    precompiled ``face_blur_filter/retinaface.rbln`` is loaded.

    Otherwise the upstream ``RetinaFaceFilter`` is built (with ``torch.load``
    forced to ``weights_only=True`` on CPU) and its ``net`` is compiled with
    ``rebel.compile_from_torch`` for a float32 ``frames`` input of shape
    ``[batch_size, 3, image_size[0], image_size[1]]``.

    ``self.net`` is then replaced by an ``RBLNPytorchRuntime``. When
    ``face_blur_filter.create_runtimes`` is false, it wraps an
    ``UnavailableRuntime`` instead.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        batch_size: int = 1,
        confidence_threshold: float = 0.7,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ):
        """
        Args:
            checkpoint_id: Compiled local directory, or the guardrail
                checkpoint id passed to the upstream constructor.
            batch_size: Batch dimension of the compiled network input.
            confidence_threshold: Stored as ``self.confidence_threshold``.
            rbln_config: Aggregate safety-checker config. Its
                ``face_blur_filter`` entry supplies ``image_size``, ``npu``,
                ``device``, ``activate_profiler`` and ``create_runtimes``.

        Raises:
            rebel.core.exception.RBLNRuntimeError: If the NPU runtime cannot
                be created.
        """
        torch.nn.Module.__init__(self)
        if is_compiled_dir(checkpoint_id):
            self.compiled_model = rebel.RBLNCompiledModel(
                pathlib.Path(checkpoint_id) / "face_blur_filter" / "retinaface.rbln"
            )
            # Copy before editing so the module-level ``cfg_re50`` dict is not mutated in place.
            self.cfg = dict(cfg_re50)
            self.batch_size = batch_size
            self.confidence_threshold = confidence_threshold
            self.cfg["pretrain"] = False
        else:
            with patch("torch.load", partial(torch.load, weights_only=True, map_location=torch.device("cpu"))):
                super().__init__(checkpoint_id)
            # The parent constructor only receives ``checkpoint_id``. Re-apply the
            # caller's values so the compiled input shape (below) and the stored
            # threshold honour the arguments rather than the parent's defaults.
            self.batch_size = batch_size
            self.confidence_threshold = confidence_threshold

            net = self.net
            del self.net
            self.compiled_model = rebel.compile_from_torch(
                net,
                input_info=[
                    (
                        "frames",
                        [
                            self.batch_size,
                            3,
                            rbln_config.face_blur_filter.image_size[0],
                            rbln_config.face_blur_filter.image_size[1],
                        ],
                        "float32",
                    )
                ],
                npu=rbln_config.face_blur_filter.npu,
            )

        self.rbln_config = rbln_config
        face_cfg = self.rbln_config.face_blur_filter

        try:
            runtime = (
                rebel.Runtime(
                    self.compiled_model,
                    tensor_type="pt",
                    device=face_cfg.device,
                    activate_profiler=face_cfg.activate_profiler,
                )
                if face_cfg.create_runtimes
                else UnavailableRuntime()
            )
        except rebel.core.exception.RBLNRuntimeError as e:
            error_msg = (
                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
                f"If you only need to compile the model without loading it to NPU, you can use:\n"
                f" from_pretrained(..., rbln_create_runtimes=False) or\n"
                f" from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
                f"Make sure your NPU is properly installed and operational."
            )
            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

        self.net = RBLNPytorchRuntime(runtime)

    def save_pretrained(self, checkpoint_id: str):
        """Write the compiled network to ``<checkpoint_id>/face_blur_filter/retinaface.rbln``, creating the directory if needed."""
        cache_path = pathlib.Path(checkpoint_id) / "face_blur_filter"
        cache_path.mkdir(parents=True, exist_ok=True)
        self.compiled_model.save(cache_path / "retinaface.rbln")
217
+
218
+
219
class RBLNVideoSafetyModel(VideoSafetyModel):
    """Cosmos video content-safety classifier head, compiled for and executed on an RBLN NPU.

    Loads ``video_content_safety_filter/safety_filter.rbln`` from a compiled directory, or
    downloads ``video_content_safety_filter/safety_filter.pt`` from ``checkpoint_id`` and
    compiles a ``SafetyClassifier`` with the shapes from ``rbln_config.video_safety_model``.
    """

    def __init__(
        self,
        config: ModelConfig,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
    ):
        """Load (or compile) the classifier and create its RBLN runtime.

        Args:
            config: Classifier configuration; ``config.num_classes`` sizes the output layer.
            checkpoint_id: Compiled directory to load from, or the hub checkpoint to compile from.
            rbln_config: Safety-checker configuration. A default
                ``RBLNCosmosSafetyCheckerConfig`` is used when ``None``.

        Raises:
            rebel.core.exception.RBLNRuntimeError: If the NPU runtime cannot be created.
        """
        torch.nn.Module.__init__(self)
        if rbln_config is None:
            # Same fallback as RBLNCosmosSafetyChecker; the code below always dereferences rbln_config.
            rbln_config = RBLNCosmosSafetyCheckerConfig()
        self.config = config
        self.num_classes = config.num_classes
        self.rbln_config = rbln_config
        model_cfg = rbln_config.video_safety_model

        if is_compiled_dir(checkpoint_id):
            self.compiled_model = rebel.RBLNCompiledModel(
                pathlib.Path(checkpoint_id) / "video_content_safety_filter" / "safety_filter.rbln"
            )
        else:
            network = SafetyClassifier(input_size=model_cfg.input_size, num_classes=self.num_classes)
            network.eval()

            checkpoint_dir = pathlib.Path(snapshot_download(checkpoint_id)) / "video_content_safety_filter"
            safety_filter_local_path = os.path.join(checkpoint_dir.as_posix(), "safety_filter.pt")
            # map_location forces CPU deserialization so the load does not need CUDA on the host.
            checkpoint = torch.load(safety_filter_local_path, weights_only=True, map_location=torch.device("cpu"))
            # Checkpoint keys carry a "network." prefix that the bare SafetyClassifier does not use.
            network.load_state_dict({k.replace("network.", ""): v for k, v in checkpoint["model"].items()})

            self.compiled_model = rebel.compile_from_torch(
                network,
                input_info=[
                    (
                        "data",
                        [model_cfg.batch_size, model_cfg.input_size],
                        "float32",
                    )
                ],
                npu=model_cfg.npu,
            )

        try:
            runtime = (
                rebel.Runtime(
                    self.compiled_model,
                    tensor_type="pt",
                    device=model_cfg.device,
                    activate_profiler=model_cfg.activate_profiler,
                )
                if model_cfg.create_runtimes
                else UnavailableRuntime()
            )
        except rebel.core.exception.RBLNRuntimeError as e:
            error_msg = (
                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
                f"If you only need to compile the model without loading it to NPU, you can use:\n"
                f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
                f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
                f"Make sure your NPU is properly installed and operational."
            )
            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e

        self.network = RBLNPytorchRuntime(runtime)

    def save_pretrained(self, checkpoint_id: str):
        """Write the compiled classifier to ``<checkpoint_id>/video_content_safety_filter/safety_filter.rbln``."""
        cache_path = pathlib.Path(checkpoint_id) / "video_content_safety_filter"
        cache_path.mkdir(parents=True, exist_ok=True)
        self.compiled_model.save(cache_path / "safety_filter.rbln")

    def parameters(self, recurse: bool = True):
        """Yield one placeholder CPU float32 tensor instead of real parameters.

        The weights live inside the compiled RBLN model. Presumably the parent class
        inspects ``next(self.parameters())`` for dtype/device (TODO confirm). ``recurse`` is
        accepted for signature compatibility with ``torch.nn.Module.parameters`` and is ignored.
        """
        yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
295
+
296
+
297
class RBLNVideoContentSafetyFilter(VideoContentSafetyFilter):
    """Video content-safety filter whose SigLIP encoder and classifier head both run on an RBLN NPU."""

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        rbln_config: Optional["RBLNCosmosSafetyCheckerConfig"] = None,
    ):
        """Build the RBLN SigLIP encoder first, then the RBLN classifier head.

        Args:
            checkpoint_id: Compiled directory or hub checkpoint shared by both components.
            rbln_config: Safety-checker configuration forwarded unchanged to both components.
        """
        torch.nn.Module.__init__(self)
        self.rbln_config = rbln_config
        shared_kwargs = dict(checkpoint_id=checkpoint_id, rbln_config=rbln_config)

        self.encoder = RBLNSigLIPEncoder(**shared_kwargs)

        # Classifier head: input_size=1152, 7 output classes (presumably the SigLIP embedding width).
        head_config = ModelConfig(num_classes=7, input_size=1152)
        self.model = RBLNVideoSafetyModel(head_config, **shared_kwargs)

    def save_pretrained(self, checkpoint_id: str):
        """Save the classifier head, then the encoder, under ``checkpoint_id``."""
        for component in (self.model, self.encoder):
            component.save_pretrained(checkpoint_id)
313
+
314
+
315
class RBLNLlamaGuard3(LlamaGuard3):
    """Llama Guard 3 text-safety model whose causal LM runs on an RBLN NPU."""

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        base_model_id: str = "meta-llama/Llama-Guard-3-8B",
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ) -> None:
        """Load the compiled model and tokenizer, or compile them from the upstream guardrail.

        Args:
            checkpoint_id: Compiled directory (``<dir>/llamaguard3``) or guardrail checkpoint.
            base_model_id: Base Llama Guard model id, used only when compiling.
            rbln_config: Safety-checker configuration. ``rbln_config.llamaguard3`` configures the
                RBLN causal LM. A default ``RBLNCosmosSafetyCheckerConfig`` is used when ``None``.
        """
        if rbln_config is None:
            # Same fallback as RBLNCosmosSafetyChecker; both paths read rbln_config.llamaguard3.
            rbln_config = RBLNCosmosSafetyCheckerConfig()

        if is_compiled_dir(checkpoint_id):
            # Skip LlamaGuard3.__init__ so the PyTorch model is never loaded.
            torch.nn.Module.__init__(self)
            cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
            self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
            self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.llamaguard3)
        else:
            super().__init__(checkpoint_id, base_model_id)
            torch_model = self.model
            # Unregister the torch submodule first; torch.nn.Module refuses to overwrite a child
            # module with a non-Module value (presumably what the RBLN model is).
            del self.model
            self.model = RBLNAutoModelForCausalLM.from_model(torch_model, rbln_config=rbln_config.llamaguard3)

        self.rbln_config = rbln_config
        # Overrides for the parent's bookkeeping; presumably read by LlamaGuard3 when preparing inputs.
        self.dtype = torch.bfloat16
        self.device = torch.device("cpu")

    def save_pretrained(self, checkpoint_id: str):
        """Save the compiled model and tokenizer under ``<checkpoint_id>/llamaguard3``."""
        cache_dir = pathlib.Path(checkpoint_id) / "llamaguard3"
        self.model.save_pretrained(cache_dir)
        self.tokenizer.save_pretrained(cache_dir)
342
+
343
+
344
class RBLNCosmosSafetyChecker(CosmosSafetyChecker):
    """
    RBLN-accelerated implementation of Cosmos Safety Checker.

    Text prompts go through a blocklist plus an RBLN Llama Guard 3 model. Generated video
    goes through an RBLN content-safety filter and then an RBLN RetinaFace face-blur postprocessor.
    """

    def __init__(
        self,
        checkpoint_id: str = COSMOS_GUARDRAIL_CHECKPOINT,
        llamaguard_model_id: str = "meta-llama/Llama-Guard-3-8B",
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
    ) -> None:
        """Assemble the text and video guardrail runners from their RBLN components.

        Args:
            checkpoint_id: Compiled directory or guardrail checkpoint for the RBLN components.
            llamaguard_model_id: Base Llama Guard model id used when compiling.
            rbln_config: Config object, plain dict of config fields, or ``None`` for defaults.

        Raises:
            ImportError: If ``cosmos_guardrail`` is not installed.
        """
        torch.nn.Module.__init__(self)
        if not COSMOS_AVAILABLE:
            raise ImportError(
                "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
            )

        if isinstance(rbln_config, dict):
            rbln_config = RBLNCosmosSafetyCheckerConfig(**rbln_config)
        elif rbln_config is None:
            rbln_config = RBLNCosmosSafetyCheckerConfig()

        # The blocklist is never saved with the compiled artifacts, so it is always loaded
        # from the upstream guardrail checkpoint rather than from `checkpoint_id`.
        blocklist = Blocklist(COSMOS_GUARDRAIL_CHECKPOINT)
        llamaguard = RBLNLlamaGuard3(
            checkpoint_id=checkpoint_id,
            base_model_id=llamaguard_model_id,
            rbln_config=rbln_config,
        )
        self.text_guardrail = GuardrailRunner(safety_models=[blocklist, llamaguard])

        content_filter = RBLNVideoContentSafetyFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        face_blur = RBLNRetinaFaceFilter(checkpoint_id=checkpoint_id, rbln_config=rbln_config)
        self.video_guardrail = GuardrailRunner(safety_models=[content_filter], postprocessors=[face_blur])

        self.rbln_config = rbln_config

    def save_pretrained(self, save_dir: str):
        """Save every RBLN-backed guardrail component, then the config, into ``save_dir``."""
        savable_groups = (
            (self.text_guardrail.safety_models, RBLNLlamaGuard3),
            (self.video_guardrail.safety_models, RBLNVideoContentSafetyFilter),
            (self.video_guardrail.postprocessors, RBLNRetinaFaceFilter),
        )
        for members, rbln_type in savable_groups:
            for member in members:
                if isinstance(member, rbln_type):
                    member.save_pretrained(save_dir)

        # Ad-hoc: the config is marked frozen before saving (presumably required by `save`).
        self.rbln_config._frozen = True
        self.rbln_config.save(save_dir)

    @classmethod
    def from_pretrained(
        cls,
        checkpoint_id: str,
        rbln_config: Optional[RBLNCosmosSafetyCheckerConfig] = None,
        subfolder: Optional[str] = None,
        export: Optional[bool] = True,
        **kwargs,
    ):
        """Build a checker from ``checkpoint_id`` (optionally ``checkpoint_id/subfolder``).

        ``export`` is accepted for interface compatibility; it is not read here. Any
        keyword that ``prepare_rbln_config`` does not consume raises ``ValueError``.
        """
        rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)

        if kwargs:
            raise ValueError(f"Unexpected arguments: {kwargs.keys()}")

        if subfolder is not None:
            checkpoint_id = os.path.join(checkpoint_id, subfolder)

        return cls(checkpoint_id=checkpoint_id, rbln_config=rbln_config)

    @classmethod
    def prepare_rbln_config(
        cls, rbln_config: Optional[Union[Dict[str, Any], RBLNCosmosSafetyCheckerConfig]] = None, **kwargs
    ) -> Tuple[RBLNCosmosSafetyCheckerConfig, Dict[str, Any]]:
        """Split RBLN config options out of ``kwargs``; return ``(config, remaining_kwargs)``."""
        config, remaining = RBLNCosmosSafetyCheckerConfig.initialize_from_kwargs(rbln_config, **kwargs)
        return config, remaining
@@ -0,0 +1,128 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, Optional
17
+
18
+ from diffusers import CosmosTextToWorldPipeline
19
+ from diffusers.schedulers import EDMEulerScheduler
20
+ from transformers import T5TokenizerFast
21
+
22
+ from ....transformers.models.t5.modeling_t5 import RBLNT5EncoderModel
23
+ from ....utils.logging import get_logger
24
+ from ...modeling_diffusers import RBLNDiffusionMixin
25
+ from ...models.autoencoders.autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
26
+ from ...models.transformers.transformer_cosmos import RBLNCosmosTransformer3DModel
27
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
28
+
29
+
30
logger = get_logger(__name__)


class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
    """
    RBLN-accelerated implementation of Cosmos Text to World pipeline for text-to-video generation.

    The T5 text encoder, the Cosmos 3D transformer and the Cosmos VAE are compiled to run on
    RBLN NPUs. The Cosmos guardrail (``RBLNCosmosSafetyChecker``) is an optional submodule.
    """

    original_class = CosmosTextToWorldPipeline
    _submodules = ["text_encoder", "transformer", "vae"]
    _optional_submodules = ["safety_checker"]

    def __init__(
        self,
        text_encoder: RBLNT5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: RBLNCosmosTransformer3DModel,
        vae: RBLNAutoencoderKLCosmos,
        scheduler: EDMEulerScheduler,
        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
    ):
        """Assemble the pipeline from already constructed RBLN submodules.

        If ``safety_checker`` is ``None``, a ``RBLNCosmosSafetyChecker()`` with its default
        configuration is constructed here.
        """
        if safety_checker is None:
            safety_checker = RBLNCosmosSafetyChecker()

        super().__init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            safety_checker=safety_checker,
        )

    def handle_additional_kwargs(self, **kwargs):
        """Drop call arguments that conflict with the transformer's compiled static shapes.

        ``num_frames`` and ``max_sequence_length`` are fixed at compile time. If a user passes a
        different value, a warning is logged and the argument is removed so the compiled values apply.

        Returns:
            The remaining keyword arguments.
        """
        compiled_config = self.transformer.rbln_config
        if "num_frames" in kwargs and kwargs["num_frames"] != compiled_config.num_frames:
            logger.warning(
                f"The transformer in this pipeline is compiled with 'num_frames={compiled_config.num_frames}'. 'num_frames' set by the user will be ignored"
            )
            kwargs.pop("num_frames")
        if "max_sequence_length" in kwargs and kwargs["max_sequence_length"] != compiled_config.max_seq_len:
            logger.warning(
                f"The transformer in this pipeline is compiled with 'max_seq_len={compiled_config.max_seq_len}'. 'max_sequence_length' set by the user will be ignored"
            )
            kwargs.pop("max_sequence_length")
        return kwargs

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        *,
        export: bool = False,
        safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Load a pretrained Cosmos Text to World pipeline, with optional compilation for RBLN NPUs.

        This method has two distinct operating modes:
            - When `export=True`: Takes a PyTorch-based pipeline, compiles it for RBLN NPUs, and loads the compiled model
            - When `export=False`: Loads an already compiled RBLN pipeline from `model_id` without recompilation

        Args:
            model_id (`str`):
                The model ID or path to the pretrained model to load. Can be either:

                - A model ID from the HuggingFace Hub
                - A local path to a saved model directory
            export:
                If True, takes a PyTorch model from `model_id` and compiles it for RBLN NPU execution.
                If False, loads an already compiled RBLN model from `model_id` without recompilation.
            safety_checker:
                Optional custom safety checker. It is forwarded to the parent loader in both modes.
                When it is `None` and `export=True`, a new `RBLNCosmosSafetyChecker` is built
                from `rbln_config.safety_checker`.
            rbln_config:
                Configuration options for RBLN compilation. It can include settings for specific submodules
                such as `text_encoder`, `transformer`, `vae`, and `safety_checker`. Defaults to an empty
                configuration.
            kwargs:
                Additional arguments to pass to the underlying diffusion pipeline constructor or the
                RBLN compilation process. These may include parameters specific to individual submodules.
        """
        # Fresh dict per call; a `{}` default would be a single object shared across calls.
        if rbln_config is None:
            rbln_config = {}

        rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
        if safety_checker is None and export:
            safety_checker = RBLNCosmosSafetyChecker(rbln_config=rbln_config.safety_checker)

        return super().from_pretrained(
            model_id, export=export, safety_checker=safety_checker, rbln_config=rbln_config, **kwargs
        )