optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
@@ -12,17 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import importlib
16
15
  from typing import TYPE_CHECKING, Dict, Optional, Union
17
16
 
18
17
  import torch
19
18
  from diffusers import ControlNetModel
20
- from diffusers.models.controlnet import ControlNetOutput
19
+ from diffusers.models.controlnets.controlnet import ControlNetOutput
21
20
  from transformers import PretrainedConfig
22
21
 
23
22
  from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
24
23
  from ...modeling import RBLNModel
25
24
  from ...utils.logging import get_logger
25
+ from ...utils.model_utils import get_rbln_model_cls
26
26
  from ..configurations import RBLNControlNetModelConfig
27
27
  from ..modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
28
28
 
@@ -98,6 +98,15 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
98
98
 
99
99
 
100
100
  class RBLNControlNetModel(RBLNModel):
101
+ """
102
+ RBLN implementation of ControlNetModel for diffusion models.
103
+
104
+ This model is used to accelerate ControlNetModel models from diffusers library on RBLN NPUs.
105
+
106
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
107
+ the library implements for all its models.
108
+ """
109
+
101
110
  hf_library_name = "diffusers"
102
111
  auto_model_class = ControlNetModel
103
112
  output_class = ControlNetOutput
@@ -122,13 +131,10 @@ class RBLNControlNetModel(RBLNModel):
122
131
 
123
132
  @classmethod
124
133
  def update_rbln_config_using_pipe(
125
- cls,
126
- pipe: RBLNDiffusionMixin,
127
- rbln_config: "RBLNDiffusionMixinConfig",
128
- submodule_name: str,
134
+ cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
129
135
  ) -> "RBLNDiffusionMixinConfig":
130
- rbln_vae_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.vae.__class__.__name__}")
131
- rbln_unet_cls = getattr(importlib.import_module("optimum.rbln"), f"RBLN{pipe.unet.__class__.__name__}")
136
+ rbln_vae_cls = get_rbln_model_cls(f"RBLN{pipe.vae.__class__.__name__}")
137
+ rbln_unet_cls = get_rbln_model_cls(f"RBLN{pipe.unet.__class__.__name__}")
132
138
 
133
139
  rbln_config.controlnet.max_seq_len = pipe.text_encoder.config.max_position_embeddings
134
140
  text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
@@ -13,4 +13,5 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from .prior_transformer import RBLNPriorTransformer
16
+ from .transformer_cosmos import RBLNCosmosTransformer3DModel
16
17
  from .transformer_sd3 import RBLNSD3Transformer2DModel
@@ -56,6 +56,16 @@ class _PriorTransformer(torch.nn.Module):
56
56
 
57
57
 
58
58
  class RBLNPriorTransformer(RBLNModel):
59
+ """
60
+ RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.
61
+
62
+ The Prior Transformer takes text and/or image embeddings from encoders (like CLIP) and
63
+ maps them to a shared latent space that guides the diffusion process to generate the desired image.
64
+
65
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
66
+ the library implements for all its models.
67
+ """
68
+
59
69
  hf_library_name = "diffusers"
60
70
  auto_model_class = PriorTransformer
61
71
  _output_class = PriorTransformerOutput
@@ -0,0 +1,321 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, List, Optional, Union
17
+
18
+ import rebel
19
+ import torch
20
+ from diffusers import CosmosTransformer3DModel
21
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
22
+ from diffusers.models.transformers.transformer_cosmos import (
23
+ CosmosEmbedding,
24
+ CosmosLearnablePositionalEmbed,
25
+ CosmosPatchEmbed,
26
+ CosmosRotaryPosEmbed,
27
+ )
28
+ from torchvision import transforms
29
+
30
+ from ....configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNModelConfig
31
+ from ....modeling import RBLNModel
32
+ from ....utils.logging import get_logger
33
+ from ...configurations import RBLNCosmosTransformer3DModelConfig
34
+
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
38
+
39
+ from ...modeling_diffusers import RBLNCosmosTransformer3DModelConfig, RBLNDiffusionMixin, RBLNDiffusionMixinConfig
40
+
41
+
42
+ logger = get_logger(__name__)
43
+
44
+
45
+ class CosmosTransformer3DModelWrapper(torch.nn.Module):
46
+ def __init__(
47
+ self,
48
+ model: CosmosTransformer3DModel,
49
+ num_latent_frames: int = 16,
50
+ latent_height: int = 88,
51
+ latent_width: int = 160,
52
+ ) -> None:
53
+ super().__init__()
54
+ self.model = model
55
+ self.num_latent_frames = num_latent_frames
56
+ self.latent_height = latent_height
57
+ self.latent_width = latent_width
58
+ self.p_t, self.p_h, self.p_w = model.config.patch_size
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ encoder_hidden_states: torch.Tensor,
64
+ embedded_timestep: torch.Tensor,
65
+ temb: torch.Tensor,
66
+ image_rotary_emb_0: torch.Tensor,
67
+ image_rotary_emb_1: torch.Tensor,
68
+ extra_pos_emb: Optional[torch.Tensor] = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ return_dict: bool = False,
71
+ ):
72
+ image_rotary_emb = [image_rotary_emb_0, image_rotary_emb_1]
73
+ for block in self.model.transformer_blocks:
74
+ hidden_states = block(
75
+ hidden_states=hidden_states,
76
+ encoder_hidden_states=encoder_hidden_states,
77
+ embedded_timestep=embedded_timestep,
78
+ temb=temb,
79
+ image_rotary_emb=image_rotary_emb,
80
+ extra_pos_emb=extra_pos_emb,
81
+ attention_mask=attention_mask,
82
+ )
83
+ post_patch_num_frames = self.num_latent_frames // self.p_t
84
+ post_patch_height = self.latent_height // self.p_h
85
+ post_patch_width = self.latent_width // self.p_w
86
+ hidden_states = self.model.norm_out(hidden_states, embedded_timestep, temb)
87
+ hidden_states = self.model.proj_out(hidden_states)
88
+ hidden_states = hidden_states.unflatten(2, (self.p_h, self.p_w, self.p_t, -1))
89
+ hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
90
+ hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
91
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
92
+
93
+ return (hidden_states,)
94
+
95
+
96
+ class RBLNCosmosTransformer3DModel(RBLNModel):
97
+ """RBLN wrapper for the Cosmos Transformer model."""
98
+
99
+ hf_library_name = "diffusers"
100
+ auto_model_class = CosmosTransformer3DModel
101
+
102
+ def __post_init__(self, **kwargs):
103
+ super().__post_init__(**kwargs)
104
+ artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
105
+
106
+ hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
107
+ patch_embed_in_channels = (
108
+ self.config.in_channels + 1 if self.config.concat_padding_mask else self.config.in_channels
109
+ )
110
+ self.rope = CosmosRotaryPosEmbed(
111
+ hidden_size=self.config.attention_head_dim,
112
+ max_size=self.config.max_size,
113
+ patch_size=self.config.patch_size,
114
+ rope_scale=self.config.rope_scale,
115
+ )
116
+ self.rope.load_state_dict(artifacts["rope"])
117
+ if artifacts["learnable_pos_embed"] is None:
118
+ self.learnable_pos_embed = None
119
+ else:
120
+ self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
121
+ hidden_size=hidden_size,
122
+ max_size=self.config.max_size,
123
+ patch_size=self.config.patch_size,
124
+ )
125
+ self.learnable_pos_embed.load_state_dict(artifacts["learnable_pos_embed"])
126
+ self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, self.config.patch_size, bias=False)
127
+ self.patch_embed.load_state_dict(artifacts["patch_embed"])
128
+ self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
129
+ self.time_embed.load_state_dict(artifacts["time_embed"])
130
+
131
+ def compute_embedding(
132
+ self,
133
+ hidden_states: torch.Tensor,
134
+ timestep: torch.Tensor,
135
+ attention_mask: Optional[torch.Tensor] = None,
136
+ fps: Optional[int] = None,
137
+ condition_mask: Optional[torch.Tensor] = None,
138
+ padding_mask: Optional[torch.Tensor] = None,
139
+ ):
140
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
141
+
142
+ # 1. Concatenate padding mask if needed & prepare attention mask
143
+ if condition_mask is not None:
144
+ hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
145
+
146
+ if self.config.concat_padding_mask:
147
+ padding_mask = transforms.functional.resize(
148
+ padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
149
+ )
150
+ hidden_states = torch.cat(
151
+ [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
152
+ )
153
+
154
+ if attention_mask is not None:
155
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
156
+
157
+ # 2. Generate positional embeddings
158
+ image_rotary_emb = self.rope(hidden_states, fps=fps)
159
+ extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
160
+
161
+ # 3. Patchify input
162
+ p_t, p_h, p_w = self.config.patch_size
163
+ hidden_states = self.patch_embed(hidden_states)
164
+ hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
165
+
166
+ # 4. Timestep embeddings
167
+ temb, embedded_timestep = self.time_embed(hidden_states, timestep)
168
+
169
+ return (
170
+ hidden_states,
171
+ temb,
172
+ embedded_timestep,
173
+ image_rotary_emb[0],
174
+ image_rotary_emb[1],
175
+ extra_pos_emb,
176
+ attention_mask,
177
+ )
178
+
179
+ @classmethod
180
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
181
+ num_latent_frames = rbln_config.num_latent_frames
182
+ latent_height = rbln_config.latent_height
183
+ latent_width = rbln_config.latent_width
184
+ return CosmosTransformer3DModelWrapper(
185
+ model=model,
186
+ num_latent_frames=num_latent_frames,
187
+ latent_height=latent_height,
188
+ latent_width=latent_width,
189
+ ).eval()
190
+
191
+ @classmethod
192
+ def update_rbln_config_using_pipe(
193
+ cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
194
+ ) -> RBLNCosmosTransformer3DModelConfig:
195
+ rbln_config.transformer.num_latent_frames = (
196
+ rbln_config.transformer.num_frames - 1
197
+ ) // pipe.vae_scale_factor_temporal + 1
198
+ rbln_config.transformer.latent_height = rbln_config.transformer.height // pipe.vae_scale_factor_spatial
199
+ rbln_config.transformer.latent_width = rbln_config.transformer.width // pipe.vae_scale_factor_spatial
200
+ rbln_config.transformer.max_seq_len = pipe.text_encoder.config.n_positions
201
+ rbln_config.transformer.embedding_dim = pipe.text_encoder.encoder.embed_tokens.embedding_dim
202
+
203
+ return rbln_config
204
+
205
+ @classmethod
206
+ def save_torch_artifacts(
207
+ cls,
208
+ model: "PreTrainedModel",
209
+ save_dir_path: Path,
210
+ subfolder: str,
211
+ rbln_config: RBLNModelConfig,
212
+ ):
213
+ save_dict = {}
214
+ save_dict["rope"] = model.rope.state_dict()
215
+ if model.learnable_pos_embed is not None:
216
+ save_dict["learnable_pos_embed"] = model.learnable_pos_embed.state_dict()
217
+ save_dict["patch_embed"] = model.patch_embed.state_dict()
218
+ save_dict["time_embed"] = model.time_embed.state_dict()
219
+ torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
220
+
221
+ @classmethod
222
+ def _update_rbln_config(
223
+ cls,
224
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
225
+ model: "PreTrainedModel",
226
+ model_config: "PretrainedConfig",
227
+ rbln_config: "RBLNCosmosTransformer3DModelConfig",
228
+ ) -> RBLNCosmosTransformer3DModelConfig:
229
+ p_t, p_h, p_w = model_config.patch_size
230
+ hidden_dim = (
231
+ (rbln_config.num_latent_frames // p_t)
232
+ * (rbln_config.latent_height // p_h)
233
+ * (rbln_config.latent_width // p_w)
234
+ )
235
+ attention_head_dim = model_config.attention_head_dim
236
+ hidden_size = model.config.num_attention_heads * model.config.attention_head_dim
237
+ input_info = [
238
+ (
239
+ "hidden_states",
240
+ [
241
+ rbln_config.batch_size,
242
+ hidden_dim,
243
+ hidden_size,
244
+ ],
245
+ "float32",
246
+ ),
247
+ (
248
+ "encoder_hidden_states",
249
+ [
250
+ rbln_config.batch_size,
251
+ rbln_config.max_seq_len,
252
+ rbln_config.embedding_dim,
253
+ ],
254
+ "float32",
255
+ ),
256
+ ("embedded_timestep", [rbln_config.batch_size, hidden_size], "float32"),
257
+ ("temb", [1, hidden_size * 3], "float32"),
258
+ ("image_rotary_emb_0", [hidden_dim, attention_head_dim], "float32"),
259
+ ("image_rotary_emb_1", [hidden_dim, attention_head_dim], "float32"),
260
+ ("extra_pos_emb", [rbln_config.batch_size, hidden_dim, hidden_size], "float32"),
261
+ ]
262
+
263
+ compile_config = RBLNCompileConfig(input_info=input_info)
264
+ rbln_config.set_compile_cfgs([compile_config])
265
+ return rbln_config
266
+
267
+ @classmethod
268
+ def _create_runtimes(
269
+ cls,
270
+ compiled_models: List[rebel.RBLNCompiledModel],
271
+ rbln_config: RBLNModelConfig,
272
+ ) -> List[rebel.Runtime]:
273
+ if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
274
+ cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
275
+
276
+ return [
277
+ rebel.Runtime(
278
+ compiled_model,
279
+ tensor_type="pt",
280
+ device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
281
+ activate_profiler=rbln_config.activate_profiler,
282
+ timeout=120,
283
+ )
284
+ for compiled_model in compiled_models
285
+ ]
286
+
287
+ def forward(
288
+ self,
289
+ hidden_states: torch.Tensor,
290
+ timestep: torch.Tensor,
291
+ encoder_hidden_states: torch.Tensor,
292
+ attention_mask: Optional[torch.Tensor] = None,
293
+ fps: Optional[int] = None,
294
+ condition_mask: Optional[torch.Tensor] = None,
295
+ padding_mask: Optional[torch.Tensor] = None,
296
+ return_dict: bool = True,
297
+ ):
298
+ (
299
+ hidden_states,
300
+ temb,
301
+ embedded_timestep,
302
+ image_rotary_emb_0,
303
+ image_rotary_emb_1,
304
+ extra_pos_emb,
305
+ attention_mask,
306
+ ) = self.compute_embedding(hidden_states, timestep, attention_mask, fps, condition_mask, padding_mask)
307
+
308
+ hidden_states = self.model[0].forward(
309
+ hidden_states,
310
+ encoder_hidden_states,
311
+ embedded_timestep,
312
+ temb,
313
+ image_rotary_emb_0,
314
+ image_rotary_emb_1,
315
+ extra_pos_emb,
316
+ )
317
+
318
+ if not return_dict:
319
+ return (hidden_states,)
320
+ else:
321
+ return Transformer2DModelOutput(sample=hidden_states)
@@ -59,6 +59,8 @@ class SD3Transformer2DModelWrapper(torch.nn.Module):
59
59
 
60
60
 
61
61
  class RBLNSD3Transformer2DModel(RBLNModel):
62
+ """RBLN wrapper for the Stable Diffusion 3 MMDiT Transformer model."""
63
+
62
64
  hf_library_name = "diffusers"
63
65
  auto_model_class = SD3Transformer2DModel
64
66
  _output_class = Transformer2DModelOutput
@@ -140,6 +140,13 @@ class _UNet_Kandinsky(torch.nn.Module):
140
140
 
141
141
 
142
142
  class RBLNUNet2DConditionModel(RBLNModel):
143
+ """
144
+ Configuration class for RBLN UNet2DCondition models.
145
+
146
+ This class inherits from RBLNModelConfig and provides specific configuration options
147
+ for UNet2DCondition models used in diffusion-based image generation.
148
+ """
149
+
143
150
  hf_library_name = "diffusers"
144
151
  auto_model_class = UNet2DConditionModel
145
152
  _rbln_config_class = RBLNUNet2DConditionModelConfig
@@ -178,7 +185,10 @@ class RBLNUNet2DConditionModel(RBLNModel):
178
185
  rbln_config: RBLNUNet2DConditionModelConfig,
179
186
  image_size: Optional[Tuple[int, int]] = None,
180
187
  ) -> Tuple[int, int]:
181
- scale_factor = pipe.movq_scale_factor if hasattr(pipe, "movq_scale_factor") else pipe.vae_scale_factor
188
+ if hasattr(pipe, "movq"):
189
+ scale_factor = 2 ** (len(pipe.movq.config.block_out_channels) - 1)
190
+ else:
191
+ scale_factor = pipe.vae_scale_factor
182
192
 
183
193
  if image_size is None:
184
194
  if "Img2Img" in pipe.__class__.__name__:
@@ -25,6 +25,11 @@ _import_structure = {
25
25
  "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
26
26
  "RBLNStableDiffusionXLControlNetPipeline",
27
27
  ],
28
+ "cosmos": [
29
+ "RBLNCosmosTextToWorldPipeline",
30
+ "RBLNCosmosVideoToWorldPipeline",
31
+ "RBLNCosmosSafetyChecker",
32
+ ],
28
33
  "kandinsky2_2": [
29
34
  "RBLNKandinskyV22CombinedPipeline",
30
35
  "RBLNKandinskyV22Img2ImgCombinedPipeline",
@@ -58,6 +63,11 @@ if TYPE_CHECKING:
58
63
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
59
64
  RBLNStableDiffusionXLControlNetPipeline,
60
65
  )
66
+ from .cosmos import (
67
+ RBLNCosmosSafetyChecker,
68
+ RBLNCosmosTextToWorldPipeline,
69
+ RBLNCosmosVideoToWorldPipeline,
70
+ )
61
71
  from .kandinsky2_2 import (
62
72
  RBLNKandinskyV22CombinedPipeline,
63
73
  RBLNKandinskyV22Img2ImgCombinedPipeline,
@@ -14,7 +14,7 @@
14
14
 
15
15
  import os
16
16
  from pathlib import Path
17
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
17
+ from typing import Any, Dict, List, Optional, Union
18
18
 
19
19
  import torch
20
20
  from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
@@ -24,9 +24,6 @@ from ....utils.logging import get_logger
24
24
  from ...models.controlnet import RBLNControlNetModel
25
25
 
26
26
 
27
- if TYPE_CHECKING:
28
- pass
29
-
30
27
  logger = get_logger(__name__)
31
28
 
32
29
 
@@ -49,6 +49,13 @@ logger = get_logger(__name__)
49
49
 
50
50
 
51
51
  class RBLNStableDiffusionControlNetPipeline(RBLNDiffusionMixin, StableDiffusionControlNetPipeline):
52
+ """
53
+ RBLN-accelerated implementation of Stable Diffusion pipeline with ControlNet for guided text-to-image generation.
54
+
55
+ This pipeline compiles Stable Diffusion and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
56
+ inference for generating images with precise structural control using conditioning inputs like edges, depth, or poses.
57
+ """
58
+
52
59
  original_class = StableDiffusionControlNetPipeline
53
60
  _rbln_config_class = RBLNStableDiffusionControlNetPipelineConfig
54
61
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
@@ -47,6 +47,13 @@ logger = logging.get_logger(__name__)
47
47
 
48
48
 
49
49
  class RBLNStableDiffusionControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionControlNetImg2ImgPipeline):
50
+ """
51
+ RBLN-accelerated implementation of Stable Diffusion pipeline with ControlNet for guided image-to-image generation.
52
+
53
+ This pipeline compiles Stable Diffusion and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
54
+ inference for transforming input images with precise structural control and conditioning guidance.
55
+ """
56
+
50
57
  original_class = StableDiffusionControlNetImg2ImgPipeline
51
58
  _submodules = ["text_encoder", "unet", "vae", "controlnet"]
52
59
  _rbln_config_class = RBLNStableDiffusionControlNetImg2ImgPipelineConfig
@@ -47,6 +47,13 @@ logger = logging.get_logger(__name__)
47
47
 
48
48
 
49
49
  class RBLNStableDiffusionXLControlNetPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetPipeline):
50
+ """
51
+ RBLN-accelerated implementation of Stable Diffusion XL pipeline with ControlNet for high-resolution guided text-to-image generation.
52
+
53
+ This pipeline compiles Stable Diffusion XL and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
54
+ inference for generating high-quality images with precise structural control and enhanced detail preservation.
55
+ """
56
+
50
57
  original_class = StableDiffusionXLControlNetPipeline
51
58
  _rbln_config_class = RBLNStableDiffusionXLControlNetPipelineConfig
52
59
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
@@ -47,6 +47,13 @@ logger = logging.get_logger(__name__)
47
47
 
48
48
 
49
49
  class RBLNStableDiffusionXLControlNetImg2ImgPipeline(RBLNDiffusionMixin, StableDiffusionXLControlNetImg2ImgPipeline):
50
+ """
51
+ RBLN-accelerated implementation of Stable Diffusion XL pipeline with ControlNet for high-resolution guided image-to-image generation.
52
+
53
+ This pipeline compiles Stable Diffusion XL and ControlNet models to run efficiently on RBLN NPUs, enabling high-performance
54
+ inference for transforming input images with precise structural control and enhanced quality preservation.
55
+ """
56
+
50
57
  original_class = StableDiffusionXLControlNetImg2ImgPipeline
51
58
  _rbln_config_class = RBLNStableDiffusionXLControlNetImg2ImgPipelineConfig
52
59
  _submodules = ["text_encoder", "text_encoder_2", "unet", "vae", "controlnet"]
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .cosmos_guardrail import RBLNCosmosSafetyChecker
16
+ from .pipeline_cosmos_text2world import RBLNCosmosTextToWorldPipeline
17
+ from .pipeline_cosmos_video2world import RBLNCosmosVideoToWorldPipeline
@@ -0,0 +1,102 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
+ from ....transformers import RBLNSiglipVisionModelConfig
19
+
20
+
21
class RBLNVideoSafetyModelConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Video Content Safety Filter.

    Compile-time settings for the video content safety classifier used as a
    submodule of the Cosmos safety checker.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        input_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size: Inference batch size. Defaults to 1.
            input_size: Feature dimension of the classifier input. Defaults to 1152,
                matching the value the Cosmos safety-checker config passes in.
            image_size: Accepted but neither stored nor used by this config.
                NOTE(review): presumably kept for signature parity with sibling
                configs — confirm this is intentional, otherwise it is silently
                dropped.
            **kwargs: Forwarded to RBLNModelConfig.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        self.input_size = input_size or 1152
36
+
37
+
38
class RBLNRetinaFaceFilterConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Retina Face Filter.

    Compile-time settings for the face blur filter submodule of the Cosmos
    safety checker.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size: Inference batch size. Defaults to 1.
            image_size: (height, width) of the processed frames.
                Defaults to (704, 1280).
            **kwargs: Forwarded to RBLNModelConfig.
        """
        # Hand any remaining options to the generic RBLN model config first.
        super().__init__(**kwargs)
        # Falsy values (None/0) fall back to the defaults, mirroring the
        # convention used by the sibling guardrail configs.
        self.batch_size = batch_size or 1
        self.image_size = image_size or (704, 1280)
52
+
53
+
54
class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
    """
    Configuration class for RBLN Cosmos Safety Checker.

    Aggregates the compile configurations of the guardrail submodules (aegis,
    video content safety classifier, face blur filter, and SigLIP vision
    encoder) that make up the Cosmos safety checker.
    """

    submodules = ["aegis", "video_safety_model", "face_blur_filter", "siglip_encoder"]

    def __init__(
        self,
        aegis: Optional[RBLNModelConfig] = None,
        video_safety_model: Optional[RBLNModelConfig] = None,
        face_blur_filter: Optional[RBLNModelConfig] = None,
        siglip_encoder: Optional[RBLNSiglipVisionModelConfig] = None,
        *,
        batch_size: Optional[int] = None,
        image_size: Optional[Tuple[int, int]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        # PEP 484: a **kwargs annotation describes each value, not the mapping,
        # so `Any` is correct here (was `Dict[str, Any]`).
        **kwargs: Any,
    ):
        """
        Args:
            aegis: Config for the aegis guardrail submodule.
            video_safety_model: Config for the video content safety classifier.
            face_blur_filter: Config for the retina-face blur filter.
            siglip_encoder: Config for the SigLIP vision encoder.
            batch_size: Batch size propagated to the submodule configs.
            image_size: (height, width) propagated to the face blur filter.
            height: Convenience alias; combined with `width` into `image_size`.
            width: Convenience alias; combined with `height` into `image_size`.
            **kwargs: Forwarded to RBLNModelConfig.

        Raises:
            ValueError: If only one of `height`/`width` is provided.
        """
        super().__init__(**kwargs)
        # Previously a lone `height` or `width` was silently discarded; make
        # that misuse explicit instead.
        if (height is None) != (width is None):
            raise ValueError("`height` and `width` must be provided together.")
        if height is not None and width is not None:
            image_size = (height, width)

        self.aegis = self.init_submodule_config(RBLNModelConfig, aegis)
        # The SigLIP encoder input is pinned to 384x384 here, independent of
        # the frame `image_size` used by the face blur filter.
        self.siglip_encoder = self.init_submodule_config(
            RBLNSiglipVisionModelConfig,
            siglip_encoder,
            batch_size=batch_size,
            image_size=(384, 384),
        )

        self.video_safety_model = self.init_submodule_config(
            RBLNVideoSafetyModelConfig,
            video_safety_model,
            batch_size=batch_size,
            input_size=1152,
        )
        self.face_blur_filter = self.init_submodule_config(
            RBLNRetinaFaceFilterConfig,
            face_blur_filter,
            batch_size=batch_size,
            image_size=image_size,
        )
98
+
99
+
100
# Make the guardrail configs discoverable by name through RBLNAutoConfig.
for _guardrail_config_cls in (
    RBLNVideoSafetyModelConfig,
    RBLNRetinaFaceFilterConfig,
    RBLNCosmosSafetyCheckerConfig,
):
    RBLNAutoConfig.register(_guardrail_config_cls)