optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,344 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, List, Optional, Union
17
+
18
+ import rebel
19
+ import torch
20
+ from diffusers import CosmosTransformer3DModel
21
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
22
+ from diffusers.models.transformers.transformer_cosmos import (
23
+ CosmosEmbedding,
24
+ CosmosLearnablePositionalEmbed,
25
+ CosmosPatchEmbed,
26
+ CosmosRotaryPosEmbed,
27
+ )
28
+ from torchvision import transforms
29
+
30
+ from ....configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNModelConfig
31
+ from ....modeling import RBLNModel
32
+ from ....utils.logging import get_logger
33
+ from ...configurations import RBLNCosmosTransformer3DModelConfig
34
+
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
38
+
39
+ from ...modeling_diffusers import RBLNCosmosTransformer3DModelConfig, RBLNDiffusionMixin, RBLNDiffusionMixinConfig
40
+
41
+
42
# Module-level logger, obtained from the package's `utils.logging.get_logger` helper.
logger = get_logger(__name__)
43
+
44
+
45
class CosmosTransformer3DModelWrapper(torch.nn.Module):
    """Wrap the transformer blocks and output head of a Cosmos transformer for RBLN compilation.

    The wrapper runs every entry of ``model.transformer_blocks``, then ``model.norm_out`` and
    ``model.proj_out``, and finally unpatchifies the token sequence back into a latent video.
    The input-side embeddings (patch / positional / timestep) are expected to be computed by the
    caller and passed in as tensors. The latent sizes are fixed at construction time so the
    unpatchify reshape uses static shapes.

    Args:
        model: The diffusers ``CosmosTransformer3DModel`` to wrap; only its ``config.patch_size``,
            ``transformer_blocks``, ``norm_out`` and ``proj_out`` are used here.
        num_latent_frames: Number of latent frames before temporal patching.
        latent_height: Latent height before spatial patching.
        latent_width: Latent width before spatial patching.
    """

    def __init__(
        self,
        model: "CosmosTransformer3DModel",
        num_latent_frames: int = 16,
        latent_height: int = 88,
        latent_width: int = 160,
    ) -> None:
        super().__init__()
        self.model = model
        self.num_latent_frames = num_latent_frames
        self.latent_height = latent_height
        self.latent_width = latent_width
        self.p_t, self.p_h, self.p_w = model.config.patch_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        embedded_timestep: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb_0: torch.Tensor,
        image_rotary_emb_1: torch.Tensor,
        extra_pos_emb: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = False,
    ):
        """Run the block stack and output head, returning a 1-tuple with the unpatchified output.

        ``image_rotary_emb_0`` / ``image_rotary_emb_1`` are the two rotary-embedding tensors, passed
        separately (instead of as a list) so each one is a distinct traced input. ``return_dict`` is
        accepted for signature compatibility but ignored: the result is always ``(hidden_states,)``.
        """
        image_rotary_emb = [image_rotary_emb_0, image_rotary_emb_1]
        for block in self.model.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                embedded_timestep=embedded_timestep,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                extra_pos_emb=extra_pos_emb,
                attention_mask=attention_mask,
            )

        # Size of the latent grid after patchification (T', H', W').
        post_patch_num_frames = self.num_latent_frames // self.p_t
        post_patch_height = self.latent_height // self.p_h
        post_patch_width = self.latent_width // self.p_w

        hidden_states = self.model.norm_out(hidden_states, embedded_timestep, temb)
        hidden_states = self.model.proj_out(hidden_states)

        # Unpatchify. After the two unflattens the layout is
        #   [B, T', H', W', p_h, p_w, p_t, C]
        # and the permute reorders it to
        #   [B, C, T', p_t, H', p_h, W', p_w],
        # so the three flattens yield [B, C, T'*p_t, H'*p_h, W'*p_w].
        # The permute order is intentionally not the literal inverse of the patchify step.
        hidden_states = hidden_states.unflatten(2, (self.p_h, self.p_w, self.p_t, -1))
        hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
        hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        return (hidden_states,)
94
+
95
+
96
class RBLNCosmosTransformer3DModel(RBLNModel):
    """
    RBLN implementation of CosmosTransformer3DModel for Cosmos video diffusion pipelines.

    The transformer blocks and the output head (norm, projection and unpatchify, as assembled by
    `CosmosTransformer3DModelWrapper`) are compiled for RBLN. The input-side modules -- rotary position
    embedding, optional learnable positional embedding, patch embedding and timestep embedding -- are
    stored in `torch_artifacts.pth` at compile time and evaluated in PyTorch by `compute_embedding`
    before each call to the compiled runtime.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    hf_library_name = "diffusers"
    auto_model_class = CosmosTransformer3DModel

    def __post_init__(self, **kwargs):
        """Rebuild the PyTorch-side embedding modules from `torch_artifacts.pth`."""
        super().__post_init__(**kwargs)
        # NOTE(review): `weights_only=False` permits arbitrary unpickling. The file is written by
        # `save_torch_artifacts` and should contain only state dicts, so `weights_only=True` is likely
        # sufficient and safer for checkpoints from untrusted sources -- confirm before switching.
        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)

        hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
        # One extra input channel carries the padding mask when `concat_padding_mask` is enabled
        # (see `compute_embedding`).
        patch_embed_in_channels = (
            self.config.in_channels + 1 if self.config.concat_padding_mask else self.config.in_channels
        )

        self.rope = CosmosRotaryPosEmbed(
            hidden_size=self.config.attention_head_dim,
            max_size=self.config.max_size,
            patch_size=self.config.patch_size,
            rope_scale=self.config.rope_scale,
        )
        self.rope.load_state_dict(artifacts["rope"])

        # Artifacts saved from a model without a learnable positional embedding may omit this key
        # entirely, so a missing key is treated the same as an explicit None.
        learnable_pos_embed_state = artifacts.get("learnable_pos_embed")
        if learnable_pos_embed_state is None:
            self.learnable_pos_embed = None
        else:
            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                hidden_size=hidden_size,
                max_size=self.config.max_size,
                patch_size=self.config.patch_size,
            )
            self.learnable_pos_embed.load_state_dict(learnable_pos_embed_state)

        self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, self.config.patch_size, bias=False)
        self.patch_embed.load_state_dict(artifacts["patch_embed"])
        self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
        self.time_embed.load_state_dict(artifacts["time_embed"])

    def compute_embedding(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
        condition_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Run the PyTorch-side embedding stages that precede the compiled transformer blocks.

        Args:
            hidden_states (torch.Tensor): Latent video of shape `[B, C, T, H, W]`.
            timestep (torch.Tensor): Current denoising timestep(s), passed to the timestep embedding.
            attention_mask (Optional[torch.Tensor]): Optional mask; reshaped to `[B, 1, 1, S]` when given.
            fps (Optional[int]): Frames per second, forwarded to the rotary position embedding.
            condition_mask (Optional[torch.Tensor]): When given, concatenated to `hidden_states` along dim 1.
            padding_mask (Optional[torch.Tensor]): Required when `config.concat_padding_mask` is set. It is
                resized to the spatial size of `hidden_states`, repeated over batch and frames, and
                concatenated along dim 1.

        Returns:
            tuple: `(hidden_states, temb, embedded_timestep, image_rotary_emb_0, image_rotary_emb_1,
            extra_pos_emb, attention_mask)`. Here `hidden_states` is the patchified token sequence, and
            `extra_pos_emb` is None when the model has no learnable positional embedding.

        Raises:
            ValueError: If `config.concat_padding_mask` is set but `padding_mask` is None.
        """
        batch_size, _, num_frames = hidden_states.shape[:3]

        # 1. Concatenate condition / padding masks along the channel dim & prepare attention mask
        if condition_mask is not None:
            hidden_states = torch.cat([hidden_states, condition_mask], dim=1)

        if self.config.concat_padding_mask:
            if padding_mask is None:
                raise ValueError("`padding_mask` must be provided when `concat_padding_mask` is enabled.")
            padding_mask = transforms.functional.resize(
                padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
            )
            hidden_states = torch.cat(
                [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
            )

        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

        # 2. Generate positional embeddings (computed on the un-patchified input)
        image_rotary_emb = self.rope(hidden_states, fps=fps)
        extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.learnable_pos_embed is not None else None

        # 3. Patchify input
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

        # 4. Timestep embeddings
        temb, embedded_timestep = self.time_embed(hidden_states, timestep)

        return (
            hidden_states,
            temb,
            embedded_timestep,
            image_rotary_emb[0],
            image_rotary_emb[1],
            extra_pos_emb,
            attention_mask,
        )

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        """Wrap `model` in `CosmosTransformer3DModelWrapper` using the latent sizes from `rbln_config`."""
        return CosmosTransformer3DModelWrapper(
            model=model,
            num_latent_frames=rbln_config.num_latent_frames,
            latent_height=rbln_config.latent_height,
            latent_width=rbln_config.latent_width,
        ).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """
        Fill the transformer sub-config from the parent pipeline.

        Latent frame/height/width are derived from the pixel-space sizes and the pipeline's VAE scale
        factors. `max_seq_len` and `embedding_dim` are read from the pipeline's text encoder.
        `rbln_config.transformer` is updated in place, and the pipeline-level `rbln_config` is returned.
        """
        transformer_config = rbln_config.transformer
        transformer_config.num_latent_frames = (
            transformer_config.num_frames - 1
        ) // pipe.vae_scale_factor_temporal + 1
        transformer_config.latent_height = transformer_config.height // pipe.vae_scale_factor_spatial
        transformer_config.latent_width = transformer_config.width // pipe.vae_scale_factor_spatial
        transformer_config.max_seq_len = pipe.text_encoder.config.n_positions
        transformer_config.embedding_dim = pipe.text_encoder.encoder.embed_tokens.embedding_dim

        return rbln_config

    @classmethod
    def save_torch_artifacts(
        cls,
        model: "PreTrainedModel",
        save_dir_path: Path,
        subfolder: str,
        rbln_config: RBLNModelConfig,
    ):
        """Save the state dicts of the PyTorch-side embedding modules to `torch_artifacts.pth`."""
        save_dict = {
            "rope": model.rope.state_dict(),
            # Always stored (None when the model has no learnable positional embedding) so the loader
            # in `__post_init__` can distinguish "absent" from an incomplete file.
            "learnable_pos_embed": (
                model.learnable_pos_embed.state_dict() if model.learnable_pos_embed is not None else None
            ),
            "patch_embed": model.patch_embed.state_dict(),
            "time_embed": model.time_embed.state_dict(),
        }
        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: "RBLNCosmosTransformer3DModelConfig",
    ) -> RBLNCosmosTransformer3DModelConfig:
        """Declare the static input shapes of the compiled wrapper and register its compile config."""
        p_t, p_h, p_w = model_config.patch_size
        # Length of the patchified token sequence: T' * H' * W'.
        seq_len = (
            (rbln_config.num_latent_frames // p_t)
            * (rbln_config.latent_height // p_h)
            * (rbln_config.latent_width // p_w)
        )
        attention_head_dim = model_config.attention_head_dim
        hidden_size = model_config.num_attention_heads * attention_head_dim
        batch_size = rbln_config.batch_size

        # Order must match the positional parameters of `CosmosTransformer3DModelWrapper.forward`.
        input_info = [
            ("hidden_states", [batch_size, seq_len, hidden_size], "float32"),
            (
                "encoder_hidden_states",
                [batch_size, rbln_config.max_seq_len, rbln_config.embedding_dim],
                "float32",
            ),
            ("embedded_timestep", [batch_size, hidden_size], "float32"),
            # `temb` and `embedded_timestep` come from the same `time_embed` call in `compute_embedding`,
            # so `temb` is declared with the batch dimension as well rather than a fixed 1.
            ("temb", [batch_size, hidden_size * 3], "float32"),
            ("image_rotary_emb_0", [seq_len, attention_head_dim], "float32"),
            ("image_rotary_emb_1", [seq_len, attention_head_dim], "float32"),
            # NOTE(review): declared unconditionally, even for models without a learnable positional
            # embedding (where `compute_embedding` yields None) -- confirm such models are unsupported.
            ("extra_pos_emb", [batch_size, seq_len, hidden_size], "float32"),
        ]

        compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([compile_config])
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNModelConfig,
    ) -> List[rebel.Runtime]:
        """Create one `rebel.Runtime` per compiled model on the device mapped to the default model name."""
        if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
            cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])

        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model in compiled_models
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
        condition_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Forward pass for the RBLN-optimized CosmosTransformer3DModel.

        Args:
            hidden_states (torch.Tensor): Noisy latent video of shape `[B, C, T, H, W]`.
            timestep (torch.Tensor): Current denoising step.
            encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            attention_mask (Optional[torch.Tensor]): Accepted for API compatibility. The compiled graph has no
                attention-mask input, so this mask is not forwarded to the runtime.
            fps (Optional[int]): Frames per second for the video being generated.
            condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
            padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
            return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            (Union[`~diffusers.models.modeling_outputs.Transformer2DModelOutput`, Tuple])
        """
        (
            hidden_states,
            temb,
            embedded_timestep,
            image_rotary_emb_0,
            image_rotary_emb_1,
            extra_pos_emb,
            _,  # attention mask: not an input of the compiled graph
        ) = self.compute_embedding(hidden_states, timestep, attention_mask, fps, condition_mask, padding_mask)

        hidden_states = self.model[0].forward(
            hidden_states,
            encoder_hidden_states,
            embedded_timestep,
            temb,
            image_rotary_emb_0,
            image_rotary_emb_1,
            extra_pos_emb,
        )

        if not return_dict:
            return (hidden_states,)
        return Transformer2DModelOutput(sample=hidden_states)
@@ -0,0 +1,191 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
16
+
17
+ import torch
18
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
19
+ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
20
+ from transformers import PretrainedConfig
21
+
22
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
23
+ from ....modeling import RBLNModel
24
+ from ....utils.logging import get_logger
25
+ from ...configurations import RBLNSD3Transformer2DModelConfig
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
30
+
31
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
class SD3Transformer2DModelWrapper(torch.nn.Module):
    """Thin ``torch.nn.Module`` adapter around an ``SD3Transformer2DModel``.

    Only the four tensor inputs are passed to the wrapped model, and tuple output is always
    requested (``return_dict=False``). ``block_controlnet_hidden_states``, ``joint_attention_kwargs``
    and the caller's ``return_dict`` are accepted for signature compatibility but ignored.
    """

    def __init__(self, model: "SD3Transformer2DModel") -> None:
        super().__init__()
        self.model = model

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        # NOTE(review): ControlNet residuals are not forwarded yet (controlnet support TBD).
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        model_inputs = {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_projections,
            "timestep": timestep,
        }
        # Tuple output regardless of the caller's return_dict.
        return self.model(**model_inputs, return_dict=False)
+ )
59
+
60
+
61
class RBLNSD3Transformer2DModel(RBLNModel):
    """
    RBLN implementation of SD3Transformer2DModel, the transformer used by Stable Diffusion 3 pipelines.

    The diffusers model is wrapped by ``SD3Transformer2DModelWrapper`` and compiled with static shapes
    for four float32 inputs: ``hidden_states``, ``encoder_hidden_states``, ``pooled_projections`` and
    ``timestep``.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    hf_library_name = "diffusers"
    auto_model_class = SD3Transformer2DModel
    _output_class = Transformer2DModelOutput

    def __post_init__(self, **kwargs):
        # Nothing model-specific to set up; defer entirely to RBLNModel.
        super().__post_init__(**kwargs)

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        """Always wrap ``model`` and return the wrapper in eval mode."""
        wrapper = SD3Transformer2DModelWrapper(model)
        return wrapper.eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """
        Fill in transformer settings that depend on the owning pipeline.

        If no sample size is configured, it is derived from ``rbln_config.image_size`` divided by
        ``pipe.vae_scale_factor``. When no image size is set either, it falls back to
        ``pipe.default_sample_size``. The transformer's ``prompt_embed_length`` is always set to
        ``pipe.tokenizer_max_length + rbln_config.max_seq_len``.
        """
        if rbln_config.sample_size is None:
            image_size = rbln_config.image_size
            if image_size is None:
                rbln_config.transformer.sample_size = pipe.default_sample_size
            else:
                scale = pipe.vae_scale_factor
                rbln_config.transformer.sample_size = (image_size[0] // scale, image_size[1] // scale)

        rbln_config.transformer.prompt_embed_length = pipe.tokenizer_max_length + rbln_config.max_seq_len
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNSD3Transformer2DModelConfig,
    ) -> RBLNSD3Transformer2DModelConfig:
        """
        Resolve the sample size and register the single compile config with static input shapes.

        The sample size defaults to ``model_config.sample_size``. An int sample size is expanded to a
        square ``(size, size)`` tuple.
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        if isinstance(rbln_config.sample_size, int):
            side = rbln_config.sample_size
            rbln_config.sample_size = (side, side)

        batch = rbln_config.batch_size
        height = rbln_config.sample_size[0]
        width = rbln_config.sample_size[1]

        # Each entry is (input name, static shape, dtype).
        input_info = [
            ("hidden_states", [batch, model_config.in_channels, height, width], "float32"),
            (
                "encoder_hidden_states",
                [batch, rbln_config.prompt_embed_length, model_config.joint_attention_dim],
                "float32",
            ),
            ("pooled_projections", [batch, model_config.pooled_projection_dim], "float32"),
            ("timestep", [batch], "float32"),
        ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config

    @property
    def compiled_batch_size(self):
        # The first compile input is "hidden_states" (see _update_rbln_config); its shape's leading dim is the batch.
        hidden_states_info = self.rbln_config.compile_cfgs[0].input_info[0]
        return hidden_states_info[1][0]

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for the RBLN-optimized SD3Transformer2DModel.

        Args:
            hidden_states (torch.FloatTensor): The currently predicted image embeddings.
            encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the
                input conditions such as prompts) to use.
            pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
            timestep (torch.LongTensor): Current denoising step. Note that it is compiled as a float32 input.
            block_controlnet_hidden_states (List): Accepted for API compatibility; ignored.
            joint_attention_kwargs (Optional[Dict[str, Any]]): Accepted for API compatibility; ignored.
            return_dict (bool): Whether or not to return a
                [`~diffusers.models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple.
            **kwargs: Ignored.

        Returns:
            (Union[`~diffusers.models.modeling_outputs.Transformer2DModelOutput`, Tuple])

        Raises:
            ValueError: If the runtime batch size is exactly double or half the compiled batch size.
        """
        runtime_batch = hidden_states.shape[0]
        compiled_batch = self.compiled_batch_size

        # A 2x mismatch is the typical guidance-scale symptom, so give a targeted hint.
        if runtime_batch != compiled_batch:
            if runtime_batch * 2 == compiled_batch or runtime_batch == compiled_batch * 2:
                raise ValueError(
                    f"Mismatch between transformer's runtime batch size ({runtime_batch}) and compiled batch size ({compiled_batch}). "
                    "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
                    "Adjust the batch size of transformer during compilation.\n\n"
                    "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/stable_diffusion_3.html#important-batch-size-configuration-for-guidance-scale"
                )

        return super().forward(
            hidden_states, encoder_hidden_states, pooled_projections, timestep, return_dict=return_dict
        )
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .unet_2d_condition import RBLNUNet2DConditionModel
16
+ from .unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModel