optimum-rbln 0.9.3__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff compares two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (80)
  1. optimum/rbln/__init__.py +0 -12
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +2 -4
  4. optimum/rbln/diffusers/__init__.py +0 -12
  5. optimum/rbln/diffusers/configurations/__init__.py +0 -3
  6. optimum/rbln/diffusers/configurations/models/__init__.py +0 -2
  7. optimum/rbln/diffusers/configurations/pipelines/__init__.py +0 -3
  8. optimum/rbln/diffusers/models/__init__.py +3 -17
  9. optimum/rbln/diffusers/models/autoencoders/__init__.py +0 -1
  10. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  11. optimum/rbln/diffusers/models/autoencoders/vae.py +8 -27
  12. optimum/rbln/diffusers/models/controlnet.py +1 -16
  13. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -16
  14. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -16
  15. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -14
  16. optimum/rbln/diffusers/models/unets/__init__.py +0 -1
  17. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -17
  18. optimum/rbln/diffusers/pipelines/__init__.py +0 -4
  19. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +0 -20
  20. optimum/rbln/modeling.py +45 -20
  21. optimum/rbln/modeling_base.py +1 -0
  22. optimum/rbln/transformers/configuration_generic.py +27 -0
  23. optimum/rbln/transformers/modeling_attention_utils.py +109 -242
  24. optimum/rbln/transformers/modeling_generic.py +61 -2
  25. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +2 -28
  26. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +5 -68
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +2 -23
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -86
  29. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +15 -42
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +2 -40
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +44 -5
  32. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +1 -6
  33. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +2 -6
  34. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +9 -17
  35. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +12 -36
  36. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +0 -17
  37. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +0 -24
  38. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -17
  39. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -3
  40. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +8 -24
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +5 -3
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +24 -36
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -2
  44. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  45. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  46. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -13
  47. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
  48. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +3 -2
  49. optimum/rbln/transformers/models/resnet/configuration_resnet.py +0 -17
  50. optimum/rbln/transformers/models/resnet/modeling_resnet.py +0 -73
  51. optimum/rbln/transformers/models/roberta/modeling_roberta.py +0 -33
  52. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +4 -2
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +10 -34
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +1 -17
  55. optimum/rbln/transformers/models/swin/modeling_swin.py +1 -14
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  57. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +2 -16
  58. optimum/rbln/transformers/models/vit/modeling_vit.py +0 -19
  59. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +3 -15
  60. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +8 -60
  61. optimum/rbln/transformers/models/whisper/generation_whisper.py +14 -48
  62. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  63. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +0 -43
  64. optimum/rbln/transformers/utils/rbln_quantization.py +0 -9
  65. optimum/rbln/utils/depreacate_utils.py +16 -0
  66. optimum/rbln/utils/hub.py +3 -14
  67. optimum/rbln/utils/runtime_utils.py +0 -32
  68. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +2 -2
  69. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +72 -79
  70. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +1 -1
  71. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +0 -67
  72. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +0 -59
  73. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +0 -114
  74. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +0 -275
  75. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +0 -201
  76. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +0 -15
  77. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +0 -46
  78. optimum/rbln/utils/deprecation.py +0 -213
  79. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.3rc0.dist-info}/entry_points.txt +0 -0
  80. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -186,16 +186,12 @@ _import_structure = {
  "diffusers": [
  "RBLNAutoencoderKL",
  "RBLNAutoencoderKLConfig",
- "RBLNAutoencoderKLTemporalDecoder",
- "RBLNAutoencoderKLTemporalDecoderConfig",
  "RBLNAutoencoderKLCosmos",
  "RBLNAutoencoderKLCosmosConfig",
  "RBLNAutoPipelineForImage2Image",
  "RBLNAutoPipelineForInpainting",
  "RBLNAutoPipelineForText2Image",
  "RBLNControlNetModel",
- "RBLNUNetSpatioTemporalConditionModel",
- "RBLNStableVideoDiffusionPipeline",
  "RBLNControlNetModelConfig",
  "RBLNCosmosTextToWorldPipeline",
  "RBLNCosmosVideoToWorldPipeline",
@@ -254,8 +250,6 @@ _import_structure = {
  "RBLNUNet2DConditionModelConfig",
  "RBLNVQModel",
  "RBLNVQModelConfig",
- "RBLNUNetSpatioTemporalConditionModelConfig",
- "RBLNStableVideoDiffusionPipelineConfig",
  ],
  }

@@ -266,8 +260,6 @@ if TYPE_CHECKING:
  RBLNAutoencoderKLConfig,
  RBLNAutoencoderKLCosmos,
  RBLNAutoencoderKLCosmosConfig,
- RBLNAutoencoderKLTemporalDecoder,
- RBLNAutoencoderKLTemporalDecoderConfig,
  RBLNAutoPipelineForImage2Image,
  RBLNAutoPipelineForInpainting,
  RBLNAutoPipelineForText2Image,
@@ -326,12 +318,8 @@ if TYPE_CHECKING:
  RBLNStableDiffusionXLInpaintPipelineConfig,
  RBLNStableDiffusionXLPipeline,
  RBLNStableDiffusionXLPipelineConfig,
- RBLNStableVideoDiffusionPipeline,
- RBLNStableVideoDiffusionPipelineConfig,
  RBLNUNet2DConditionModel,
  RBLNUNet2DConditionModelConfig,
- RBLNUNetSpatioTemporalConditionModel,
- RBLNUNetSpatioTemporalConditionModelConfig,
  RBLNVQModel,
  RBLNVQModelConfig,
  )
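Note: the removals above drop every Stable Video Diffusion entry point from the top-level `optimum.rbln` namespace, so those names import under 0.9.3 but not under 0.9.3rc0. A minimal, hypothetical guard for downstream code (assuming the package's lazy-module loader surfaces a missing name as an ImportError, as it does for the exports that remain) might look like:

```python
# Hypothetical compatibility shim: RBLNStableVideoDiffusionPipeline is exported
# by optimum-rbln 0.9.3 but not by 0.9.3rc0.
try:
    from optimum.rbln import RBLNStableVideoDiffusionPipeline
except ImportError:
    RBLNStableVideoDiffusionPipeline = None  # skip the SVD code path on 0.9.3rc0
```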
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.9.3'
- __version_tuple__ = version_tuple = (0, 9, 3)
+ __version__ = version = '0.9.3rc0'
+ __version_tuple__ = version_tuple = (0, 9, 3, 'rc0')

  __commit_id__ = commit_id = None
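For downstream pins it is worth noting that PEP 440 orders a release candidate below its final release; an illustrative check with the `packaging` library (the same one configuration_utils.py imports below):

```python
from packaging.version import Version

# A release candidate sorts before its final release under PEP 440,
# mirroring the wheel's version_tuple change from (0, 9, 3) to (0, 9, 3, 'rc0').
assert Version("0.9.3rc0") < Version("0.9.3")
print(Version("0.9.3rc0").pre)  # ('rc', 0)
```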
optimum/rbln/configuration_utils.py CHANGED
@@ -24,7 +24,7 @@ import torch
  from packaging.version import Version

  from .__version__ import __version__
- from .utils.deprecation import warn_deprecated_npu
+ from .utils.depreacate_utils import warn_deprecated_npu
  from .utils.logging import get_logger
  from .utils.runtime_utils import ContextRblnConfig

@@ -528,7 +528,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  ]
  submodules: List[str] = []
  subclass_non_save_attributes = []
- _allow_no_compile_cfgs = False

  def initialize_submodule_config(
  self,
@@ -809,8 +808,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  or len(self._compile_cfgs) == 0
  or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
  ):
- if not self._allow_no_compile_cfgs:
- raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")
+ raise RuntimeError("`compile_cfgs` must be set before freezing.")

  for submodule_name in self.submodules:
  submodule_config = getattr(self, submodule_name, None)
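The net effect of the two hunks above is that freezing a config with no compile configs always fails on the 0.9.3rc0 side: the `_allow_no_compile_cfgs` bypass present in 0.9.3 is gone. A simplified sketch of the validation path, with the standalone function and the first clause invented here purely for illustration:

```python
from typing import List


class RBLNCompileConfig:
    """Stand-in for the real optimum.rbln compile config class."""


def validate_compile_cfgs(compile_cfgs: List[RBLNCompileConfig]) -> None:
    # 0.9.3rc0 behaviour: no bypass attribute, so an empty or ill-typed list
    # raises unconditionally when the config is frozen.
    if (
        not isinstance(compile_cfgs, list)
        or len(compile_cfgs) == 0
        or not all(isinstance(cfg, RBLNCompileConfig) for cfg in compile_cfgs)
    ):
        raise RuntimeError("`compile_cfgs` must be set before freezing.")
```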
optimum/rbln/diffusers/__init__.py CHANGED
@@ -57,9 +57,6 @@ _import_structure = {
  "RBLNSD3Transformer2DModelConfig",
  "RBLNUNet2DConditionModelConfig",
  "RBLNVQModelConfig",
- "RBLNUNetSpatioTemporalConditionModelConfig",
- "RBLNStableVideoDiffusionPipelineConfig",
- "RBLNAutoencoderKLTemporalDecoderConfig",
  ],
  "pipelines": [
  "RBLNAutoPipelineForImage2Image",
@@ -89,17 +86,14 @@ _import_structure = {
  "RBLNStableDiffusion3Pipeline",
  "RBLNStableDiffusion3Img2ImgPipeline",
  "RBLNStableDiffusion3InpaintPipeline",
- "RBLNStableVideoDiffusionPipeline",
  ],
  "models": [
  "RBLNAutoencoderKL",
  "RBLNAutoencoderKLCosmos",
  "RBLNUNet2DConditionModel",
- "RBLNUNetSpatioTemporalConditionModel",
  "RBLNControlNetModel",
  "RBLNCosmosTransformer3DModel",
  "RBLNSD3Transformer2DModel",
- "RBLNAutoencoderKLTemporalDecoder",
  "RBLNPriorTransformer",
  "RBLNVQModel",
  ],
@@ -112,7 +106,6 @@ if TYPE_CHECKING:
  from .configurations import (
  RBLNAutoencoderKLConfig,
  RBLNAutoencoderKLCosmosConfig,
- RBLNAutoencoderKLTemporalDecoderConfig,
  RBLNControlNetModelConfig,
  RBLNCosmosTextToWorldPipelineConfig,
  RBLNCosmosTransformer3DModelConfig,
@@ -139,22 +132,18 @@ if TYPE_CHECKING:
  RBLNStableDiffusionXLImg2ImgPipelineConfig,
  RBLNStableDiffusionXLInpaintPipelineConfig,
  RBLNStableDiffusionXLPipelineConfig,
- RBLNStableVideoDiffusionPipelineConfig,
  RBLNUNet2DConditionModelConfig,
- RBLNUNetSpatioTemporalConditionModelConfig,
  RBLNVQModelConfig,
  )
  from .modeling_diffusers import RBLNDiffusionMixin
  from .models import (
  RBLNAutoencoderKL,
  RBLNAutoencoderKLCosmos,
- RBLNAutoencoderKLTemporalDecoder,
  RBLNControlNetModel,
  RBLNCosmosTransformer3DModel,
  RBLNPriorTransformer,
  RBLNSD3Transformer2DModel,
  RBLNUNet2DConditionModel,
- RBLNUNetSpatioTemporalConditionModel,
  RBLNVQModel,
  )
  from .pipelines import (
@@ -185,7 +174,6 @@ if TYPE_CHECKING:
  RBLNStableDiffusionXLImg2ImgPipeline,
  RBLNStableDiffusionXLInpaintPipeline,
  RBLNStableDiffusionXLPipeline,
- RBLNStableVideoDiffusionPipeline,
  )
  else:
  import sys
optimum/rbln/diffusers/configurations/__init__.py CHANGED
@@ -1,13 +1,11 @@
  from .models import (
  RBLNAutoencoderKLConfig,
  RBLNAutoencoderKLCosmosConfig,
- RBLNAutoencoderKLTemporalDecoderConfig,
  RBLNControlNetModelConfig,
  RBLNCosmosTransformer3DModelConfig,
  RBLNPriorTransformerConfig,
  RBLNSD3Transformer2DModelConfig,
  RBLNUNet2DConditionModelConfig,
- RBLNUNetSpatioTemporalConditionModelConfig,
  RBLNVQModelConfig,
  )
  from .pipelines import (
@@ -33,5 +31,4 @@ from .pipelines import (
  RBLNStableDiffusionXLImg2ImgPipelineConfig,
  RBLNStableDiffusionXLInpaintPipelineConfig,
  RBLNStableDiffusionXLPipelineConfig,
- RBLNStableVideoDiffusionPipelineConfig,
  )
optimum/rbln/diffusers/configurations/models/__init__.py CHANGED
@@ -1,10 +1,8 @@
  from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
  from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
- from .configuration_autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoderConfig
  from .configuration_controlnet import RBLNControlNetModelConfig
  from .configuration_prior_transformer import RBLNPriorTransformerConfig
  from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
  from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
  from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
- from .configuration_unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModelConfig
  from .configuration_vq_model import RBLNVQModelConfig
optimum/rbln/diffusers/configurations/pipelines/__init__.py CHANGED
@@ -29,6 +29,3 @@ from .configuration_stable_diffusion_xl import (
  RBLNStableDiffusionXLInpaintPipelineConfig,
  RBLNStableDiffusionXLPipelineConfig,
  )
- from .configuration_stable_video_diffusion import (
- RBLNStableVideoDiffusionPipelineConfig,
- )
optimum/rbln/diffusers/models/__init__.py CHANGED
@@ -22,11 +22,9 @@ _import_structure = {
  "RBLNAutoencoderKL",
  "RBLNAutoencoderKLCosmos",
  "RBLNVQModel",
- "RBLNAutoencoderKLTemporalDecoder",
  ],
  "unets": [
  "RBLNUNet2DConditionModel",
- "RBLNUNetSpatioTemporalConditionModel",
  ],
  "controlnet": ["RBLNControlNetModel"],
  "transformers": [
@@ -37,22 +35,10 @@ _import_structure = {
  }

  if TYPE_CHECKING:
- from .autoencoders import (
- RBLNAutoencoderKL,
- RBLNAutoencoderKLCosmos,
- RBLNAutoencoderKLTemporalDecoder,
- RBLNVQModel,
- )
+ from .autoencoders import RBLNAutoencoderKL, RBLNAutoencoderKLCosmos, RBLNVQModel
  from .controlnet import RBLNControlNetModel
- from .transformers import (
- RBLNCosmosTransformer3DModel,
- RBLNPriorTransformer,
- RBLNSD3Transformer2DModel,
- )
- from .unets import (
- RBLNUNet2DConditionModel,
- RBLNUNetSpatioTemporalConditionModel,
- )
+ from .transformers import RBLNCosmosTransformer3DModel, RBLNPriorTransformer, RBLNSD3Transformer2DModel
+ from .unets import RBLNUNet2DConditionModel
  else:
  import sys
optimum/rbln/diffusers/models/autoencoders/__init__.py CHANGED
@@ -14,5 +14,4 @@

  from .autoencoder_kl import RBLNAutoencoderKL
  from .autoencoder_kl_cosmos import RBLNAutoencoderKLCosmos
- from .autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoder
  from .vq_model import RBLNVQModel
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py CHANGED
@@ -68,7 +68,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  self.image_size = self.rbln_config.image_size

  @classmethod
- def _wrap_model_if_needed(
+ def wrap_model_if_needed(
  cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
  ) -> torch.nn.Module:
  decoder_model = _VAECosmosDecoder(model)
@@ -98,7 +98,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):

  compiled_models = {}
  if rbln_config.uses_encoder:
- encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+ encoder_model, decoder_model = cls.wrap_model_if_needed(model, rbln_config)
  enc_compiled_model = cls.compile(
  encoder_model,
  rbln_compile_config=rbln_config.compile_cfgs[0],
@@ -107,7 +107,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
  )
  compiled_models["encoder"] = enc_compiled_model
  else:
- decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+ decoder_model = cls.wrap_model_if_needed(model, rbln_config)
  dec_compiled_model = cls.compile(
  decoder_model,
  rbln_compile_config=rbln_config.compile_cfgs[-1],
optimum/rbln/diffusers/models/autoencoders/vae.py CHANGED
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, List, Union
+ from typing import TYPE_CHECKING, List

  import torch
  from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
@@ -21,7 +21,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime


  if TYPE_CHECKING:
- from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
+ from diffusers import AutoencoderKL, AutoencoderKLCosmos, VQModel


  class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
@@ -67,37 +67,18 @@ class _VAEDecoder(torch.nn.Module):
  return vae_out


- class _VAETemporalDecoder(torch.nn.Module):
- def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
- super().__init__()
- self.vae = vae
- self.num_frames = None
-
- def forward(self, z):
- vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
- return vae_out
-
-
  class _VAEEncoder(torch.nn.Module):
- def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
+ def __init__(self, vae: "AutoencoderKL"):
  super().__init__()
  self.vae = vae

  def encode(self, x: torch.FloatTensor, return_dict: bool = True):
- if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
- if self.use_tiling and (
- x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
- ):
- return self.tiled_encode(x, return_dict=return_dict)
-
- if self.use_slicing and x.shape[0] > 1:
- encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
- h = torch.cat(encoded_slices)
- else:
- h = self.encoder(x)
- if self.quant_conv is not None:
- h = self.quant_conv(h)
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)

+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
  else:
  h = self.encoder(x)
  if self.quant_conv is not None:
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -118,7 +118,7 @@ class RBLNControlNetModel(RBLNModel):
  )

  @classmethod
- def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
  use_encoder_hidden_states = False
  for down_block in model.down_blocks:
  if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -219,21 +219,6 @@ class RBLNControlNetModel(RBLNModel):
  return_dict: bool = True,
  **kwargs,
  ):
- """
- Forward pass for the RBLN-optimized ControlNetModel.
-
- Args:
- sample (torch.FloatTensor): The noisy input tensor.
- timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
- encoder_hidden_states (torch.Tensor): The encoder hidden states.
- controlnet_cond (torch.FloatTensor): The conditional input tensor of shape `(batch_size, max_seq_len, hidden_size)`.
- conditioning_scale (torch.Tensor): The scale factor for ControlNet outputs.
- added_cond_kwargs (Dict[str, torch.Tensor]): Additional conditions for the Stable Diffusion XL UNet.
- return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple
-
- Returns:
- (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)
- """
  sample_batch_size = sample.size()[0]
  compiled_batch_size = self.compiled_batch_size
  if sample_batch_size != compiled_batch_size and (
optimum/rbln/diffusers/models/transformers/prior_transformer.py CHANGED
@@ -77,7 +77,7 @@ class RBLNPriorTransformer(RBLNModel):
  self.clip_std = artifacts["clip_std"]

  @classmethod
- def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
  return _PriorTransformer(model).eval()

  @classmethod
@@ -128,27 +128,13 @@ class RBLNPriorTransformer(RBLNModel):

  def forward(
  self,
- hidden_states: torch.Tensor,
+ hidden_states,
  timestep: Union[torch.Tensor, float, int],
  proj_embedding: torch.Tensor,
  encoder_hidden_states: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  return_dict: bool = True,
  ):
- """
- Forward pass for the RBLN-optimized PriorTransformer.
-
- Args:
- hidden_states (torch.Tensor): The currently predicted image embeddings.
- timestep (Union[torch.Tensor, float, int]): Current denoising step.
- proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
- encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
- attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
- return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.
-
- Returns:
- (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
- """
  # Convert timestep(long) and attention_mask(bool) to float
  return super().forward(
  hidden_states,
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py CHANGED
@@ -185,7 +185,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
  )

  @classmethod
- def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
  num_latent_frames = rbln_config.num_latent_frames
  latent_height = rbln_config.latent_height
  latent_width = rbln_config.latent_width
@@ -303,21 +303,6 @@
  padding_mask: Optional[torch.Tensor] = None,
  return_dict: bool = True,
  ):
- """
- Forward pass for the RBLN-optimized CosmosTransformer3DModel.
-
- Args:
- hidden_states (torch.Tensor): The currently predicted image embeddings.
- timestep (torch.Tensor): Current denoising step.
- encoder_hidden_states (torch.Tensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- fps: (Optional[int]): Frames per second for the video being generated.
- condition_mask (Optional[torch.Tensor]): Tensor of condition mask.
- padding_mask (Optional[torch.Tensor]): Tensor of padding mask.
- return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
-
- Returns:
- (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
- """
  (
  hidden_states,
  temb,
optimum/rbln/diffusers/models/transformers/transformer_sd3.py CHANGED
@@ -77,7 +77,7 @@ class RBLNSD3Transformer2DModel(RBLNModel):
  super().__post_init__(**kwargs)

  @classmethod
- def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
  return SD3Transformer2DModelWrapper(model).eval()

  @classmethod
@@ -161,19 +161,6 @@
  return_dict: bool = True,
  **kwargs,
  ):
- """
- Forward pass for the RBLN-optimized SD3Transformer2DModel.
-
- Args:
- hidden_states (torch.FloatTensor): The currently predicted image embeddings.
- encoder_hidden_states (torch.FloatTensor): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (torch.FloatTensor): Embeddings projected from the embeddings of input conditions.
- timestep (torch.LongTensor): Current denoising step.
- return_dict (bool): Whether or not to return a [`~diffusers.models.modeling_output.Transformer2DModelOutput`] instead of a plain tuple.
-
- Returns:
- (Union[`~diffusers.models.modeling_output.Transformer2DModelOutput`, Tuple])
- """
  sample_batch_size = hidden_states.size()[0]
  compiled_batch_size = self.compiled_batch_size
  if sample_batch_size != compiled_batch_size and (
optimum/rbln/diffusers/models/unets/__init__.py CHANGED
@@ -13,4 +13,3 @@
  # limitations under the License.

  from .unet_2d_condition import RBLNUNet2DConditionModel
- from .unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModel
optimum/rbln/diffusers/models/unets/unet_2d_condition.py CHANGED
@@ -171,7 +171,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
  self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))

  @classmethod
- def _wrap_model_if_needed(
+ def wrap_model_if_needed(
  cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
  ) -> torch.nn.Module:
  if model.config.addition_embed_type == "text_time":
@@ -349,22 +349,6 @@
  return_dict: bool = True,
  **kwargs,
  ) -> Union[UNet2DConditionOutput, Tuple]:
- """
- Forward pass for the RBLN-optimized UNet2DConditionModel.
-
- Args:
- sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
- timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
- encoder_hidden_states (torch.Tensor): The encoder hidden states.
- added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
- if specified are added to the embeddings that are passed along to the UNet blocks.
- down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
- mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
- return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
-
- Returns:
- (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
- """
  sample_batch_size = sample.size()[0]
  compiled_batch_size = self.compiled_batch_size
  if sample_batch_size != compiled_batch_size and (
optimum/rbln/diffusers/pipelines/__init__.py CHANGED
@@ -59,9 +59,6 @@ _import_structure = {
  "RBLNStableDiffusion3Img2ImgPipeline",
  "RBLNStableDiffusion3InpaintPipeline",
  ],
- "stable_video_diffusion": [
- "RBLNStableVideoDiffusionPipeline",
- ],
  }
  if TYPE_CHECKING:
  from .auto_pipeline import (
@@ -101,7 +98,6 @@ if TYPE_CHECKING:
  RBLNStableDiffusionXLInpaintPipeline,
  RBLNStableDiffusionXLPipeline,
  )
- from .stable_video_diffusion import RBLNStableVideoDiffusionPipeline
  else:
  import sys
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -96,26 +96,6 @@ class RBLNMultiControlNetModel(RBLNModel):
  guess_mode: bool = False,
  return_dict: bool = True,
  ):
- """
- Forward pass for the RBLN-optimized MultiControlNetModel.
-
- This method processes multiple ControlNet models in sequence, applying each one to the input sample
- with its corresponding conditioning image and scale factor. The outputs from all ControlNets are
- merged by addition to produce the final control signals.
-
- Args:
- sample (torch.FloatTensor): The noisy input tensor.
- timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
- encoder_hidden_states (torch.Tensor): The encoder hidden states from the text encoder.
- controlnet_cond (List[torch.Tensor]): A list of conditional input tensors, one for each ControlNet model.
- conditioning_scale (List[float]): A list of scale factors for each ControlNet output. Each scale
- controls the strength of the corresponding ControlNet's influence on the generation.
- return_dict (bool): Whether or not to return a dictionary instead of a plain tuple. Currently,
- this method always returns a tuple regardless of this parameter.
-
- Returns:
- (Tuple[List[torch.Tensor], torch.Tensor])
- """
  for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
  down_samples, mid_sample = controlnet(
  sample=sample.contiguous(),
optimum/rbln/modeling.py CHANGED
@@ -34,6 +34,49 @@ if TYPE_CHECKING:
  logger = get_logger(__name__)


+ def _get_dtype(
+ cls,
+ dtype: Optional[Union[str, torch.dtype, dict]],
+ config: PretrainedConfig,
+ ) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
+ dtype_orig = None
+
+ if dtype is not None:
+ if isinstance(dtype, str):
+ if dtype == "auto":
+ if hasattr(config, "dtype") and config.dtype is not None:
+ dtype = config.dtype
+ else:
+ dtype = torch.get_default_dtype()
+ elif hasattr(torch, dtype):
+ dtype = getattr(torch, dtype)
+ config.dtype = dtype
+ elif isinstance(dtype, torch.dtype):
+ config.dtype = dtype
+ elif isinstance(dtype, dict):
+ for key, curr_dtype in dtype.items():
+ if hasattr(config, key):
+ value = getattr(config, key)
+ curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
+ value.dtype = curr_dtype
+ # main torch dtype for modules that aren't part of any sub-config
+ dtype = dtype.get("")
+ dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
+ config.dtype = dtype
+ if dtype is None:
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Invalid dtype: {dtype}")
+
+ dtype_orig = cls._set_default_dtype(dtype)
+ else:
+ # Use default dtype
+ default_dtype = torch.get_default_dtype()
+ config.dtype = default_dtype
+
+ return config, dtype, dtype_orig
+
+
  class RBLNModel(RBLNBaseModel):
  @classmethod
  def update_kwargs(cls, kwargs):
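The block added above introduces a module-level `_get_dtype` helper that normalizes the `dtype` argument before compilation. A standalone sketch of how the accepted forms are expected to resolve (illustrative only; the real helper also writes the result back to `config.dtype` and swaps the torch default via `cls._set_default_dtype`):

```python
import torch

def resolve_dtype(dtype, config_dtype=None):
    # Mirrors the branches of the new _get_dtype helper, without the config side effects.
    if isinstance(dtype, str):
        if dtype == "auto":
            # "auto" falls back to the config's dtype, else torch's current default
            return config_dtype if config_dtype is not None else torch.get_default_dtype()
        return getattr(torch, dtype)  # e.g. "bfloat16" -> torch.bfloat16
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, dict):
        # per-sub-config dtypes; the "" key covers modules without a sub-config
        main = dtype.get("")
        main = getattr(torch, main) if isinstance(main, str) else main
        return main if main is not None else torch.float32
    raise ValueError(f"Invalid dtype: {dtype}")

assert resolve_dtype("float16") is torch.float16
assert resolve_dtype(torch.bfloat16) is torch.bfloat16
assert resolve_dtype({"text_config": "float32"}) is torch.float32  # no "" key -> float32 default
```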
@@ -54,16 +97,13 @@ class RBLNModel(RBLNBaseModel):
  pass

  @classmethod
- def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
  # Wrap the model if needed.
  return model

  @classmethod
  def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
- if rbln_config._allow_no_compile_cfgs:
- return {}
-
- model = cls._wrap_model_if_needed(model, rbln_config)
+ model = cls.wrap_model_if_needed(model, rbln_config)
  rbln_compile_config = rbln_config.compile_cfgs[0]
  compiled_model = cls.compile(
  model,
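The hunk above (and the matching hunks in the diffusers model files earlier in this diff) renames the wrapping hook: 0.9.3 used the private `_wrap_model_if_needed`, while 0.9.3rc0 spells it `wrap_model_if_needed`, which `get_compiled_model()` calls. Out-of-tree subclasses that override the hook would need to follow the spelling of whichever version they target; a hypothetical subclass sketch:

```python
import torch
from optimum.rbln.modeling import RBLNModel


class RBLNMyModel(RBLNModel):  # hypothetical subclass, for illustration only
    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config) -> torch.nn.Module:
        # Overrides of the 0.9.3 name (_wrap_model_if_needed) would not be
        # picked up by get_compiled_model() on 0.9.3rc0, and vice versa.
        return model.eval()
```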
@@ -73,18 +113,6 @@
  )
  return compiled_model

- @classmethod
- def _update_rbln_config(
- cls,
- preprocessors: Optional[Any],
- model: Optional["PreTrainedModel"] = None,
- model_config: Optional["PretrainedConfig"] = None,
- rbln_config: Optional[RBLNModelConfig] = None,
- ) -> RBLNModelConfig:
- # Default implementation: return config as-is
- # Subclasses should override to set compile_cfgs if needed
- return rbln_config
-
  @classmethod
  def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
  return model
@@ -249,9 +277,6 @@
  compiled_models: List[rebel.RBLNCompiledModel],
  rbln_config: RBLNModelConfig,
  ) -> List[rebel.Runtime]:
- if len(rbln_config.compile_cfgs) == 0:
- return []
-
  if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
  cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
optimum/rbln/modeling_base.py CHANGED
@@ -71,6 +71,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  self.rbln_config = rbln_config
  if not rbln_config.is_frozen():
  raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
+
  self.compiled_models = rbln_compiled_models

  # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating