optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -91,6 +91,10 @@ _import_structure = {
         "RBLNGemmaModel",
         "RBLNGemmaModelConfig",
         "RBLNGemmaForCausalLM",
+        "RBLNGemma2ForCausalLM",
+        "RBLNGemma2ForCausalLMConfig",
+        "RBLNGemma2Model",
+        "RBLNGemma2ModelConfig",
         "RBLNGemmaForCausalLMConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
@@ -100,6 +104,8 @@ _import_structure = {
         "RBLNGPT2ModelConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
+        "RBLNGptOssForCausalLM",
+        "RBLNGptOssForCausalLMConfig",
         "RBLNGroundingDinoDecoder",
         "RBLNGroundingDinoDecoderConfig",
         "RBLNGroundingDinoForObjectDetection",
@@ -140,14 +146,24 @@ _import_structure = {
         "RBLNPixtralVisionModelConfig",
         "RBLNPhiModel",
         "RBLNPhiModelConfig",
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen3MoeForCausalLM",
+        "RBLNQwen3MoeForCausalLMConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
         "RBLNQwen2Model",
         "RBLNQwen2ModelConfig",
+        "RBLNQwen2MoeForCausalLM",
+        "RBLNQwen2MoeForCausalLMConfig",
         "RBLNQwen3ForCausalLM",
         "RBLNQwen3ForCausalLMConfig",
         "RBLNQwen3Model",
@@ -156,6 +172,8 @@ _import_structure = {
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
         "RBLNResNetForImageClassification",
         "RBLNResNetForImageClassificationConfig",
         "RBLNRobertaForMaskedLM",
@@ -186,12 +204,16 @@ _import_structure = {
     "diffusers": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLConfig",
+        "RBLNAutoencoderKLTemporalDecoder",
+        "RBLNAutoencoderKLTemporalDecoderConfig",
         "RBLNAutoencoderKLCosmos",
         "RBLNAutoencoderKLCosmosConfig",
         "RBLNAutoPipelineForImage2Image",
         "RBLNAutoPipelineForInpainting",
         "RBLNAutoPipelineForText2Image",
         "RBLNControlNetModel",
+        "RBLNUNetSpatioTemporalConditionModel",
+        "RBLNStableVideoDiffusionPipeline",
         "RBLNControlNetModelConfig",
         "RBLNCosmosTextToWorldPipeline",
         "RBLNCosmosVideoToWorldPipeline",
@@ -250,6 +272,8 @@ _import_structure = {
         "RBLNUNet2DConditionModelConfig",
         "RBLNVQModel",
         "RBLNVQModelConfig",
+        "RBLNUNetSpatioTemporalConditionModelConfig",
+        "RBLNStableVideoDiffusionPipelineConfig",
     ],
 }
 
@@ -260,6 +284,8 @@ if TYPE_CHECKING:
         RBLNAutoencoderKLConfig,
         RBLNAutoencoderKLCosmos,
         RBLNAutoencoderKLCosmosConfig,
+        RBLNAutoencoderKLTemporalDecoder,
+        RBLNAutoencoderKLTemporalDecoderConfig,
         RBLNAutoPipelineForImage2Image,
         RBLNAutoPipelineForInpainting,
         RBLNAutoPipelineForText2Image,
@@ -318,8 +344,12 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLInpaintPipelineConfig,
         RBLNStableDiffusionXLPipeline,
         RBLNStableDiffusionXLPipelineConfig,
+        RBLNStableVideoDiffusionPipeline,
+        RBLNStableVideoDiffusionPipelineConfig,
         RBLNUNet2DConditionModel,
         RBLNUNet2DConditionModelConfig,
+        RBLNUNetSpatioTemporalConditionModel,
+        RBLNUNetSpatioTemporalConditionModelConfig,
         RBLNVQModel,
         RBLNVQModelConfig,
     )
@@ -382,6 +412,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma2ForCausalLM,
+        RBLNGemma2ForCausalLMConfig,
+        RBLNGemma2Model,
+        RBLNGemma2ModelConfig,
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
         RBLNGemma3ForConditionalGeneration,
@@ -394,6 +428,8 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGptOssForCausalLM,
+        RBLNGptOssForCausalLMConfig,
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
         RBLNGroundingDinoEncoder,
@@ -424,6 +460,10 @@ if TYPE_CHECKING:
         RBLNOPTForCausalLMConfig,
         RBLNOPTModel,
         RBLNOPTModelConfig,
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
         RBLNPegasusModel,
@@ -438,18 +478,26 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
         RBLNQwen2Model,
         RBLNQwen2ModelConfig,
+        RBLNQwen2MoeForCausalLM,
+        RBLNQwen2MoeForCausalLMConfig,
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
         RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
         RBLNQwen3ForCausalLM,
         RBLNQwen3ForCausalLMConfig,
         RBLNQwen3Model,
         RBLNQwen3ModelConfig,
+        RBLNQwen3MoeForCausalLM,
+        RBLNQwen3MoeForCausalLMConfig,
         RBLNResNetForImageClassification,
         RBLNResNetForImageClassificationConfig,
         RBLNRobertaForMaskedLM,
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.9.3rc0'
-__version_tuple__ = version_tuple = (0, 9, 3, 'rc0')
+__version__ = version = '0.9.5a4'
+__version_tuple__ = version_tuple = (0, 9, 5, 'a4')
 
 __commit_id__ = commit_id = None
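Usage sketch for the export additions in optimum/rbln/__init__.py above: once these entries exist in `_import_structure`, the new classes should resolve lazily from the package root. The names below are taken verbatim from the added lines; the import itself is illustrative only.

import optimum.rbln  # requires optimum-rbln >= 0.9.5a4 installed

from optimum.rbln import (
    RBLNGemma2ForCausalLM,
    RBLNGptOssForCausalLM,
    RBLNPaliGemmaForConditionalGeneration,
    RBLNQwen2MoeForCausalLM,
    RBLNQwen3MoeForCausalLM,
    RBLNStableVideoDiffusionPipeline,
)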
optimum/rbln/configuration_utils.py CHANGED
@@ -24,7 +24,7 @@ import torch
 from packaging.version import Version
 
 from .__version__ import __version__
-from .utils.depreacate_utils import warn_deprecated_npu
+from .utils.deprecation import deprecate_kwarg, warn_deprecated_npu
 from .utils.logging import get_logger
 from .utils.runtime_utils import ContextRblnConfig
 
@@ -92,7 +92,7 @@ class RBLNCompileConfig:
             and isinstance(item[0], str)  # name
             and isinstance(item[1], (tuple, list))  # shape
             and all(isinstance(x, int) for x in item[1])
-            and isinstance(item[2], str)  # dtype
+            and (isinstance(item[2], str) or isinstance(item[2], torch.dtype))  # dtype
             for item in input_info
         )
 
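The relaxed check means the dtype field of each `input_info` entry may now be either a string or a `torch.dtype`. A small illustrative sketch (the tensor names are hypothetical, not taken from the package):

import torch

# Each entry is (name, shape, dtype); both dtype spellings now pass validation.
input_info = [
    ("input_ids", [1, 512], "int64"),
    ("attention_mask", [1, 512], torch.int64),  # previously rejected: not a str
]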
@@ -117,9 +117,14 @@
         return self
 
     def get_dummy_inputs(
-        self, fill=0, static_tensors: Dict[str, torch.Tensor] = {}, meta_tensor_names: List[str] = []
+        self,
+        fill=0,
+        static_tensors: Optional[Dict[str, torch.Tensor]] = None,
+        meta_tensor_names: Optional[List[str]] = None,
     ):
         dummy = []
+        static_tensors = static_tensors if static_tensors is not None else {}
+        meta_tensor_names = meta_tensor_names if meta_tensor_names is not None else []
         for name, shape, dtype in self.input_info:
             if name in static_tensors:
                 tensor = static_tensors[name]
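The reworked signature drops mutable default arguments in favor of `None` sentinels. The pitfall being avoided is standard Python behavior, shown here with a standalone repro (not optimum-rbln code):

def buggy(values=[]):      # the default list is created once, at definition time
    values.append(1)
    return values

def fixed(values=None):    # a fresh list is created on every call
    values = values if values is not None else []
    values.append(1)
    return values

print(buggy())  # [1]
print(buggy())  # [1, 1]  -- state leaked from the previous call
print(fixed())  # [1]
print(fixed())  # [1]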
@@ -255,7 +260,7 @@ class RBLNAutoConfig:
     def load(
         path: str,
         passed_rbln_config: Optional["RBLNModelConfig"] = None,
-        kwargs: Optional[Dict[str, Any]] = {},
+        kwargs: Optional[Dict[str, Any]] = None,
         return_unused_kwargs: bool = False,
     ) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
         """
@@ -269,6 +274,8 @@
         Returns:
             RBLNModelConfig: The loaded RBLNModelConfig.
         """
+        if kwargs is None:
+            kwargs = {}
         cls, config_file = load_config(path)
 
         rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
@@ -517,8 +524,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     non_save_attributes = [
         "_frozen",
         "_runtime_options",
-        "torch_dtype",
         "npu",
+        "dtype",
         "tensor_parallel_size",
         "create_runtimes",
         "device",
@@ -528,6 +535,7 @@
     ]
     submodules: List[str] = []
     subclass_non_save_attributes = []
+    _allow_no_compile_cfgs = False
 
     def initialize_submodule_config(
         self,
@@ -642,6 +650,14 @@
 
         super().__setattr__(key, value)
 
+    @deprecate_kwarg(
+        old_name="_torch_dtype",
+        new_name="dtype",
+        version="0.12.0",
+        deprecated_type=torch.dtype,
+        value_replacer=RBLNCompileConfig.normalize_dtype,
+        raise_if_greater_or_equal_version=False,
+    )
     def __init__(
         self,
         cls_name: Optional[str] = None,
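The decorator maps the old `_torch_dtype` keyword onto `dtype`. Its implementation lives in the new optimum/rbln/utils/deprecation.py, which is not shown in this diff; the sketch below is a hypothetical, simplified reconstruction of what a decorator with this call signature might do (version gating via `raise_if_greater_or_equal_version` is omitted).

import functools
import warnings
from typing import Any, Callable, Optional, Type


def deprecate_kwarg(
    old_name: str,
    new_name: str,
    version: str,
    deprecated_type: Optional[Type] = None,
    value_replacer: Optional[Callable[[Any], Any]] = None,
    raise_if_greater_or_equal_version: bool = False,
):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in {version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            value = kwargs.get(new_name)
            # Values arriving with the deprecated type (e.g. torch.dtype) get normalized.
            if deprecated_type is not None and isinstance(value, deprecated_type) and value_replacer:
                kwargs[new_name] = value_replacer(value)
            return func(*args, **kwargs)

        return wrapper

    return decorator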
@@ -653,8 +669,8 @@
         tensor_parallel_size: Optional[int] = None,
         timeout: Optional[int] = None,
         optimum_rbln_version: Optional[str] = None,
-        _torch_dtype: Optional[str] = None,
-        _compile_cfgs: List[RBLNCompileConfig] = [],
+        dtype: Optional[Union[str, torch.dtype]] = None,
+        _compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
         *,
         optimize_host_memory: Optional[bool] = None,
         **kwargs: Any,
@@ -672,7 +688,7 @@
             tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
             timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
             optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
-            _torch_dtype (Optional[str]): The data type to use for the model.
+            dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
             _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
             kwargs: Additional keyword arguments.
 
@@ -702,12 +718,15 @@
         self.npu = npu
         self.tensor_parallel_size = tensor_parallel_size
 
-        self._torch_dtype = _torch_dtype or "float32"
+        if dtype is not None and isinstance(dtype, torch.dtype):
+            dtype = RBLNCompileConfig.normalize_dtype(dtype)
+        self._dtype = dtype or "float32"
         self.optimum_rbln_version = optimum_rbln_version
         if self.optimum_rbln_version is None:
             self.optimum_rbln_version = __version__
 
-        self._compile_cfgs: List[RBLNCompileConfig] = _compile_cfgs
+        compile_cfgs = _compile_cfgs if _compile_cfgs is not None else []
+        self._compile_cfgs: List[RBLNCompileConfig] = compile_cfgs
 
         if not isinstance(self._compile_cfgs, list):
             raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
@@ -734,14 +753,24 @@
 
     @property
     def torch_dtype(self):
-        return getattr(torch, self._torch_dtype)
+        logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
+        return self.dtype
 
     @torch_dtype.setter
     def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
-        if isinstance(torch_dtype, torch.dtype):
-            torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+        logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
+        self.dtype = torch_dtype
+
+    @property
+    def dtype(self):
+        return getattr(torch, self._dtype)
+
+    @dtype.setter
+    def dtype(self, dtype: Union[str, torch.dtype]):
+        if isinstance(dtype, torch.dtype):
+            dtype = RBLNCompileConfig.normalize_dtype(dtype)
 
-        self._torch_dtype = torch_dtype
+        self._dtype = dtype
 
     @property
     def rbln_model_cls_name(self) -> str:
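Net effect of these hunks: the canonical attribute is now `dtype` (stored internally as a string), and `torch_dtype` survives only as a warning shim that delegates to it. A self-contained sketch of the same property pattern, outside optimum-rbln, assuming only that dtypes round-trip through their string names:

import warnings

import torch


class DtypeHolder:
    def __init__(self):
        self._dtype = "float32"  # stored as a string, as in RBLNModelConfig

    @property
    def dtype(self):
        return getattr(torch, self._dtype)

    @dtype.setter
    def dtype(self, value):
        if isinstance(value, torch.dtype):
            value = str(value).split(".")[-1]  # torch.bfloat16 -> "bfloat16"
        self._dtype = value

    @property
    def torch_dtype(self):
        warnings.warn("`torch_dtype` is deprecated. Use `dtype` instead.", FutureWarning)
        return self.dtype

    @torch_dtype.setter
    def torch_dtype(self, value):
        warnings.warn("`torch_dtype` is deprecated. Use `dtype` instead.", FutureWarning)
        self.dtype = value


holder = DtypeHolder()
holder.dtype = torch.bfloat16
assert holder.dtype is torch.bfloat16   # round-trips through the string "bfloat16"
holder.torch_dtype = "float16"          # old spelling still works, with a warning
assert holder.dtype is torch.float16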
@@ -765,10 +794,15 @@
             if isinstance(value, RBLNSerializableConfigProtocol):
                 # Convert nested RBLNModelConfig to its serializable form
                 serializable_map[key] = value._prepare_for_serialization()
+            elif key == "_dtype":
+                serializable_map["dtype"] = value
+            elif isinstance(value, list) and all(isinstance(item, RBLNSerializableConfigProtocol) for item in value):
+                serializable_map[key] = [item._prepare_for_serialization() for item in value]
             elif key == "_compile_cfgs":
                 serializable_map[key] = [cfg.asdict() for cfg in value]
             else:
                 serializable_map[key] = value
+
         return serializable_map
 
     def __repr__(self):
@@ -808,25 +842,20 @@
             or len(self._compile_cfgs) == 0
             or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
         ):
-            raise RuntimeError("`compile_cfgs` must be set before freezing.")
+            if not self._allow_no_compile_cfgs:
+                raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")
 
         for submodule_name in self.submodules:
             submodule_config = getattr(self, submodule_name, None)
             if not isinstance(submodule_config, RBLNModelConfig):
                 raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
 
-            if not submodule_config.is_frozen():
-                raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
-
         self._frozen = True
 
     def is_frozen(self):
         return self._frozen
 
     def save(self, path: str):
-        if not self._frozen:
-            raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
-
         # save as json file without runtime attributes
         path = Path(path)
         if path.is_dir():
optimum/rbln/diffusers/__init__.py CHANGED
@@ -57,6 +57,9 @@ _import_structure = {
         "RBLNSD3Transformer2DModelConfig",
         "RBLNUNet2DConditionModelConfig",
         "RBLNVQModelConfig",
+        "RBLNUNetSpatioTemporalConditionModelConfig",
+        "RBLNStableVideoDiffusionPipelineConfig",
+        "RBLNAutoencoderKLTemporalDecoderConfig",
     ],
     "pipelines": [
         "RBLNAutoPipelineForImage2Image",
@@ -86,14 +89,17 @@ _import_structure = {
         "RBLNStableDiffusion3Pipeline",
         "RBLNStableDiffusion3Img2ImgPipeline",
         "RBLNStableDiffusion3InpaintPipeline",
+        "RBLNStableVideoDiffusionPipeline",
     ],
     "models": [
         "RBLNAutoencoderKL",
         "RBLNAutoencoderKLCosmos",
         "RBLNUNet2DConditionModel",
+        "RBLNUNetSpatioTemporalConditionModel",
         "RBLNControlNetModel",
         "RBLNCosmosTransformer3DModel",
         "RBLNSD3Transformer2DModel",
+        "RBLNAutoencoderKLTemporalDecoder",
         "RBLNPriorTransformer",
         "RBLNVQModel",
     ],
@@ -106,6 +112,7 @@ if TYPE_CHECKING:
     from .configurations import (
         RBLNAutoencoderKLConfig,
         RBLNAutoencoderKLCosmosConfig,
+        RBLNAutoencoderKLTemporalDecoderConfig,
         RBLNControlNetModelConfig,
         RBLNCosmosTextToWorldPipelineConfig,
         RBLNCosmosTransformer3DModelConfig,
@@ -132,18 +139,22 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLImg2ImgPipelineConfig,
         RBLNStableDiffusionXLInpaintPipelineConfig,
         RBLNStableDiffusionXLPipelineConfig,
+        RBLNStableVideoDiffusionPipelineConfig,
         RBLNUNet2DConditionModelConfig,
+        RBLNUNetSpatioTemporalConditionModelConfig,
         RBLNVQModelConfig,
     )
     from .modeling_diffusers import RBLNDiffusionMixin
     from .models import (
         RBLNAutoencoderKL,
         RBLNAutoencoderKLCosmos,
+        RBLNAutoencoderKLTemporalDecoder,
         RBLNControlNetModel,
         RBLNCosmosTransformer3DModel,
         RBLNPriorTransformer,
         RBLNSD3Transformer2DModel,
         RBLNUNet2DConditionModel,
+        RBLNUNetSpatioTemporalConditionModel,
         RBLNVQModel,
     )
     from .pipelines import (
@@ -174,6 +185,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLImg2ImgPipeline,
         RBLNStableDiffusionXLInpaintPipeline,
         RBLNStableDiffusionXLPipeline,
+        RBLNStableVideoDiffusionPipeline,
     )
 else:
     import sys
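As in the root `__init__.py`, the `_import_structure` dict plus the `if TYPE_CHECKING:` block feed a lazy module, so importing `optimum.rbln.diffusers` does not eagerly pull in every pipeline. A simplified, hypothetical sketch of that pattern (the real lazy-module machinery is not reproduced in this diff):

import importlib
from types import ModuleType


class _LazySubmodules(ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # e.g. {"RBLNAutoencoderKLTemporalDecoder": "models", ...}
        self._class_to_module = {
            cls: submodule for submodule, classes in import_structure.items() for cls in classes
        }

    def __getattr__(self, item):
        # Import the owning submodule only when the attribute is first requested.
        submodule = importlib.import_module(f"{self.__name__}.{self._class_to_module[item]}")
        return getattr(submodule, item)

# A package __init__ using this pattern would then typically end with something like:
#   sys.modules[__name__] = _LazySubmodules(__name__, _import_structure)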
optimum/rbln/diffusers/configurations/__init__.py CHANGED
@@ -1,11 +1,13 @@
 from .models import (
     RBLNAutoencoderKLConfig,
     RBLNAutoencoderKLCosmosConfig,
+    RBLNAutoencoderKLTemporalDecoderConfig,
     RBLNControlNetModelConfig,
     RBLNCosmosTransformer3DModelConfig,
     RBLNPriorTransformerConfig,
     RBLNSD3Transformer2DModelConfig,
     RBLNUNet2DConditionModelConfig,
+    RBLNUNetSpatioTemporalConditionModelConfig,
     RBLNVQModelConfig,
 )
 from .pipelines import (
@@ -31,4 +33,5 @@ from .pipelines import (
     RBLNStableDiffusionXLImg2ImgPipelineConfig,
     RBLNStableDiffusionXLInpaintPipelineConfig,
     RBLNStableDiffusionXLPipelineConfig,
+    RBLNStableVideoDiffusionPipelineConfig,
 )
optimum/rbln/diffusers/configurations/models/__init__.py CHANGED
@@ -1,8 +1,10 @@
 from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
 from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
+from .configuration_autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoderConfig
 from .configuration_controlnet import RBLNControlNetModelConfig
 from .configuration_prior_transformer import RBLNPriorTransformerConfig
 from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
 from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
 from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
+from .configuration_unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModelConfig
 from .configuration_vq_model import RBLNVQModelConfig
optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py ADDED
@@ -0,0 +1,67 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNAutoencoderKLTemporalDecoderConfig(RBLNModelConfig):
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        uses_encoder: Optional[bool] = None,
+        num_frames: Optional[int] = None,
+        decode_chunk_size: Optional[int] = None,
+        vae_scale_factor: Optional[float] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the input/output images.
+                If an integer is provided, it's used for both height and width.
+            uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
+                When False, only the decoder is used (for latent-to-image conversion).
+            num_frames (Optional[int]): The number of frames in the generated video.
+            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
+                Useful for managing memory usage during video generation.
+            vae_scale_factor (Optional[float]): The scaling factor between pixel space and latent space.
+                Determines how much smaller the latent representations are compared to the original images.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.uses_encoder = uses_encoder
+        self.num_frames = num_frames
+        self.decode_chunk_size = decode_chunk_size
+        self.vae_scale_factor = vae_scale_factor
+        self.sample_size = sample_size
+        if isinstance(sample_size, int):
+            self.sample_size = (sample_size, sample_size)
+
+    @property
+    def image_size(self):
+        return self.sample_size
+
+    @property
+    def latent_sample_size(self):
+        return (self.image_size[0] // self.vae_scale_factor, self.image_size[1] // self.vae_scale_factor)
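A hypothetical construction sketch, using only the fields defined above; the values are arbitrary, and it assumes the base RBLNModelConfig can be instantiated with defaults:

from optimum.rbln import RBLNAutoencoderKLTemporalDecoderConfig

cfg = RBLNAutoencoderKLTemporalDecoderConfig(
    batch_size=1,
    sample_size=576,        # an int is normalized to (576, 576)
    uses_encoder=False,     # decoder-only: latents -> frames
    num_frames=25,
    decode_chunk_size=8,
    vae_scale_factor=8,
)
assert cfg.sample_size == (576, 576)
assert cfg.image_size == (576, 576)          # alias property
assert cfg.latent_sample_size == (72, 72)    # 576 // 8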
optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,59 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNUNetSpatioTemporalConditionModelConfig(RBLNModelConfig):
+    subclass_non_save_attributes = ["_batch_size_is_specified"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        sample_size: Optional[Tuple[int, int]] = None,
+        in_features: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the generated samples.
+                If an integer is provided, it's used for both height and width.
+            in_features (Optional[int]): Number of input features for the model.
+            num_frames (Optional[int]): The number of frames in the generated video.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self._batch_size_is_specified = batch_size is not None
+
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.in_features = in_features
+        self.num_frames = num_frames
+
+        self.sample_size = sample_size
+        if isinstance(sample_size, int):
+            self.sample_size = (sample_size, sample_size)
+
+    @property
+    def batch_size_is_specified(self):
+        return self._batch_size_is_specified
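The `_batch_size_is_specified` flag is excluded from serialization and simply records whether the caller pinned a batch size, presumably so the owning pipeline config can decide whether to derive one itself (for example, doubling for classifier-free guidance). A hypothetical sketch of the observable behavior:

from optimum.rbln import RBLNUNetSpatioTemporalConditionModelConfig

explicit = RBLNUNetSpatioTemporalConditionModelConfig(batch_size=2, num_frames=25)
implicit = RBLNUNetSpatioTemporalConditionModelConfig(num_frames=25)

assert explicit.batch_size_is_specified and explicit.batch_size == 2
assert not implicit.batch_size_is_specified and implicit.batch_size == 1  # falls back to 1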
optimum/rbln/diffusers/configurations/pipelines/__init__.py CHANGED
@@ -29,3 +29,6 @@ from .configuration_stable_diffusion_xl import (
     RBLNStableDiffusionXLInpaintPipelineConfig,
     RBLNStableDiffusionXLPipelineConfig,
 )
+from .configuration_stable_video_diffusion import (
+    RBLNStableVideoDiffusionPipelineConfig,
+)