optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
|
@@ -91,6 +91,10 @@ _import_structure = {
|
|
|
91
91
|
"RBLNGemmaModel",
|
|
92
92
|
"RBLNGemmaModelConfig",
|
|
93
93
|
"RBLNGemmaForCausalLM",
|
|
94
|
+
"RBLNGemma2ForCausalLM",
|
|
95
|
+
"RBLNGemma2ForCausalLMConfig",
|
|
96
|
+
"RBLNGemma2Model",
|
|
97
|
+
"RBLNGemma2ModelConfig",
|
|
94
98
|
"RBLNGemmaForCausalLMConfig",
|
|
95
99
|
"RBLNGemma3ForCausalLM",
|
|
96
100
|
"RBLNGemma3ForCausalLMConfig",
|
|
@@ -100,6 +104,8 @@ _import_structure = {
|
|
|
100
104
|
"RBLNGPT2ModelConfig",
|
|
101
105
|
"RBLNGPT2LMHeadModel",
|
|
102
106
|
"RBLNGPT2LMHeadModelConfig",
|
|
107
|
+
"RBLNGptOssForCausalLM",
|
|
108
|
+
"RBLNGptOssForCausalLMConfig",
|
|
103
109
|
"RBLNGroundingDinoDecoder",
|
|
104
110
|
"RBLNGroundingDinoDecoderConfig",
|
|
105
111
|
"RBLNGroundingDinoForObjectDetection",
|
|
@@ -140,14 +146,24 @@ _import_structure = {
|
|
|
140
146
|
"RBLNPixtralVisionModelConfig",
|
|
141
147
|
"RBLNPhiModel",
|
|
142
148
|
"RBLNPhiModelConfig",
|
|
149
|
+
"RBLNPaliGemmaForConditionalGeneration",
|
|
150
|
+
"RBLNPaliGemmaForConditionalGenerationConfig",
|
|
151
|
+
"RBLNPaliGemmaModel",
|
|
152
|
+
"RBLNPaliGemmaModelConfig",
|
|
143
153
|
"RBLNQwen2ForCausalLM",
|
|
144
154
|
"RBLNQwen2ForCausalLMConfig",
|
|
145
155
|
"RBLNQwen2_5_VisionTransformerPretrainedModel",
|
|
146
156
|
"RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
|
|
147
157
|
"RBLNQwen2_5_VLForConditionalGeneration",
|
|
148
158
|
"RBLNQwen2_5_VLForConditionalGenerationConfig",
|
|
159
|
+
"RBLNQwen3MoeForCausalLM",
|
|
160
|
+
"RBLNQwen3MoeForCausalLMConfig",
|
|
161
|
+
"RBLNQwen2_5_VLModel",
|
|
162
|
+
"RBLNQwen2_5_VLModelConfig",
|
|
149
163
|
"RBLNQwen2Model",
|
|
150
164
|
"RBLNQwen2ModelConfig",
|
|
165
|
+
"RBLNQwen2MoeForCausalLM",
|
|
166
|
+
"RBLNQwen2MoeForCausalLMConfig",
|
|
151
167
|
"RBLNQwen3ForCausalLM",
|
|
152
168
|
"RBLNQwen3ForCausalLMConfig",
|
|
153
169
|
"RBLNQwen3Model",
|
|
@@ -156,6 +172,8 @@ _import_structure = {
|
|
|
156
172
|
"RBLNQwen2VisionTransformerPretrainedModelConfig",
|
|
157
173
|
"RBLNQwen2VLForConditionalGeneration",
|
|
158
174
|
"RBLNQwen2VLForConditionalGenerationConfig",
|
|
175
|
+
"RBLNQwen2VLModel",
|
|
176
|
+
"RBLNQwen2VLModelConfig",
|
|
159
177
|
"RBLNResNetForImageClassification",
|
|
160
178
|
"RBLNResNetForImageClassificationConfig",
|
|
161
179
|
"RBLNRobertaForMaskedLM",
|
|
@@ -186,12 +204,16 @@ _import_structure = {
|
|
|
186
204
|
"diffusers": [
|
|
187
205
|
"RBLNAutoencoderKL",
|
|
188
206
|
"RBLNAutoencoderKLConfig",
|
|
207
|
+
"RBLNAutoencoderKLTemporalDecoder",
|
|
208
|
+
"RBLNAutoencoderKLTemporalDecoderConfig",
|
|
189
209
|
"RBLNAutoencoderKLCosmos",
|
|
190
210
|
"RBLNAutoencoderKLCosmosConfig",
|
|
191
211
|
"RBLNAutoPipelineForImage2Image",
|
|
192
212
|
"RBLNAutoPipelineForInpainting",
|
|
193
213
|
"RBLNAutoPipelineForText2Image",
|
|
194
214
|
"RBLNControlNetModel",
|
|
215
|
+
"RBLNUNetSpatioTemporalConditionModel",
|
|
216
|
+
"RBLNStableVideoDiffusionPipeline",
|
|
195
217
|
"RBLNControlNetModelConfig",
|
|
196
218
|
"RBLNCosmosTextToWorldPipeline",
|
|
197
219
|
"RBLNCosmosVideoToWorldPipeline",
|
|
@@ -250,6 +272,8 @@ _import_structure = {
|
|
|
250
272
|
"RBLNUNet2DConditionModelConfig",
|
|
251
273
|
"RBLNVQModel",
|
|
252
274
|
"RBLNVQModelConfig",
|
|
275
|
+
"RBLNUNetSpatioTemporalConditionModelConfig",
|
|
276
|
+
"RBLNStableVideoDiffusionPipelineConfig",
|
|
253
277
|
],
|
|
254
278
|
}
|
|
255
279
|
|
|
@@ -260,6 +284,8 @@ if TYPE_CHECKING:
|
|
|
260
284
|
RBLNAutoencoderKLConfig,
|
|
261
285
|
RBLNAutoencoderKLCosmos,
|
|
262
286
|
RBLNAutoencoderKLCosmosConfig,
|
|
287
|
+
RBLNAutoencoderKLTemporalDecoder,
|
|
288
|
+
RBLNAutoencoderKLTemporalDecoderConfig,
|
|
263
289
|
RBLNAutoPipelineForImage2Image,
|
|
264
290
|
RBLNAutoPipelineForInpainting,
|
|
265
291
|
RBLNAutoPipelineForText2Image,
|
|
@@ -318,8 +344,12 @@ if TYPE_CHECKING:
|
|
|
318
344
|
RBLNStableDiffusionXLInpaintPipelineConfig,
|
|
319
345
|
RBLNStableDiffusionXLPipeline,
|
|
320
346
|
RBLNStableDiffusionXLPipelineConfig,
|
|
347
|
+
RBLNStableVideoDiffusionPipeline,
|
|
348
|
+
RBLNStableVideoDiffusionPipelineConfig,
|
|
321
349
|
RBLNUNet2DConditionModel,
|
|
322
350
|
RBLNUNet2DConditionModelConfig,
|
|
351
|
+
RBLNUNetSpatioTemporalConditionModel,
|
|
352
|
+
RBLNUNetSpatioTemporalConditionModelConfig,
|
|
323
353
|
RBLNVQModel,
|
|
324
354
|
RBLNVQModelConfig,
|
|
325
355
|
)
|
|
@@ -382,6 +412,10 @@ if TYPE_CHECKING:
|
|
|
382
412
|
RBLNDPTForDepthEstimationConfig,
|
|
383
413
|
RBLNExaoneForCausalLM,
|
|
384
414
|
RBLNExaoneForCausalLMConfig,
|
|
415
|
+
RBLNGemma2ForCausalLM,
|
|
416
|
+
RBLNGemma2ForCausalLMConfig,
|
|
417
|
+
RBLNGemma2Model,
|
|
418
|
+
RBLNGemma2ModelConfig,
|
|
385
419
|
RBLNGemma3ForCausalLM,
|
|
386
420
|
RBLNGemma3ForCausalLMConfig,
|
|
387
421
|
RBLNGemma3ForConditionalGeneration,
|
|
@@ -394,6 +428,8 @@ if TYPE_CHECKING:
|
|
|
394
428
|
RBLNGPT2LMHeadModelConfig,
|
|
395
429
|
RBLNGPT2Model,
|
|
396
430
|
RBLNGPT2ModelConfig,
|
|
431
|
+
RBLNGptOssForCausalLM,
|
|
432
|
+
RBLNGptOssForCausalLMConfig,
|
|
397
433
|
RBLNGroundingDinoDecoder,
|
|
398
434
|
RBLNGroundingDinoDecoderConfig,
|
|
399
435
|
RBLNGroundingDinoEncoder,
|
|
@@ -424,6 +460,10 @@ if TYPE_CHECKING:
|
|
|
424
460
|
RBLNOPTForCausalLMConfig,
|
|
425
461
|
RBLNOPTModel,
|
|
426
462
|
RBLNOPTModelConfig,
|
|
463
|
+
RBLNPaliGemmaForConditionalGeneration,
|
|
464
|
+
RBLNPaliGemmaForConditionalGenerationConfig,
|
|
465
|
+
RBLNPaliGemmaModel,
|
|
466
|
+
RBLNPaliGemmaModelConfig,
|
|
427
467
|
RBLNPegasusForConditionalGeneration,
|
|
428
468
|
RBLNPegasusForConditionalGenerationConfig,
|
|
429
469
|
RBLNPegasusModel,
|
|
@@ -438,18 +478,26 @@ if TYPE_CHECKING:
|
|
|
438
478
|
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
|
439
479
|
RBLNQwen2_5_VLForConditionalGeneration,
|
|
440
480
|
RBLNQwen2_5_VLForConditionalGenerationConfig,
|
|
481
|
+
RBLNQwen2_5_VLModel,
|
|
482
|
+
RBLNQwen2_5_VLModelConfig,
|
|
441
483
|
RBLNQwen2ForCausalLM,
|
|
442
484
|
RBLNQwen2ForCausalLMConfig,
|
|
443
485
|
RBLNQwen2Model,
|
|
444
486
|
RBLNQwen2ModelConfig,
|
|
487
|
+
RBLNQwen2MoeForCausalLM,
|
|
488
|
+
RBLNQwen2MoeForCausalLMConfig,
|
|
445
489
|
RBLNQwen2VisionTransformerPretrainedModel,
|
|
446
490
|
RBLNQwen2VisionTransformerPretrainedModelConfig,
|
|
447
491
|
RBLNQwen2VLForConditionalGeneration,
|
|
448
492
|
RBLNQwen2VLForConditionalGenerationConfig,
|
|
493
|
+
RBLNQwen2VLModel,
|
|
494
|
+
RBLNQwen2VLModelConfig,
|
|
449
495
|
RBLNQwen3ForCausalLM,
|
|
450
496
|
RBLNQwen3ForCausalLMConfig,
|
|
451
497
|
RBLNQwen3Model,
|
|
452
498
|
RBLNQwen3ModelConfig,
|
|
499
|
+
RBLNQwen3MoeForCausalLM,
|
|
500
|
+
RBLNQwen3MoeForCausalLMConfig,
|
|
453
501
|
RBLNResNetForImageClassification,
|
|
454
502
|
RBLNResNetForImageClassificationConfig,
|
|
455
503
|
RBLNRobertaForMaskedLM,
|
optimum/rbln/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.9.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 9,
|
|
31
|
+
__version__ = version = '0.9.5a4'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 9, 5, 'a4')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -24,7 +24,7 @@ import torch
|
|
|
24
24
|
from packaging.version import Version
|
|
25
25
|
|
|
26
26
|
from .__version__ import __version__
|
|
27
|
-
from .utils.
|
|
27
|
+
from .utils.deprecation import deprecate_kwarg, warn_deprecated_npu
|
|
28
28
|
from .utils.logging import get_logger
|
|
29
29
|
from .utils.runtime_utils import ContextRblnConfig
|
|
30
30
|
|
|
@@ -92,7 +92,7 @@ class RBLNCompileConfig:
|
|
|
92
92
|
and isinstance(item[0], str) # name
|
|
93
93
|
and isinstance(item[1], (tuple, list)) # shape
|
|
94
94
|
and all(isinstance(x, int) for x in item[1])
|
|
95
|
-
and isinstance(item[2], str) # dtype
|
|
95
|
+
and (isinstance(item[2], str) or isinstance(item[2], torch.dtype)) # dtype
|
|
96
96
|
for item in input_info
|
|
97
97
|
)
|
|
98
98
|
|
|
@@ -117,9 +117,14 @@ class RBLNCompileConfig:
|
|
|
117
117
|
return self
|
|
118
118
|
|
|
119
119
|
def get_dummy_inputs(
|
|
120
|
-
self,
|
|
120
|
+
self,
|
|
121
|
+
fill=0,
|
|
122
|
+
static_tensors: Optional[Dict[str, torch.Tensor]] = None,
|
|
123
|
+
meta_tensor_names: Optional[List[str]] = None,
|
|
121
124
|
):
|
|
122
125
|
dummy = []
|
|
126
|
+
static_tensors = static_tensors if static_tensors is not None else {}
|
|
127
|
+
meta_tensor_names = meta_tensor_names if meta_tensor_names is not None else []
|
|
123
128
|
for name, shape, dtype in self.input_info:
|
|
124
129
|
if name in static_tensors:
|
|
125
130
|
tensor = static_tensors[name]
|
|
@@ -255,7 +260,7 @@ class RBLNAutoConfig:
|
|
|
255
260
|
def load(
|
|
256
261
|
path: str,
|
|
257
262
|
passed_rbln_config: Optional["RBLNModelConfig"] = None,
|
|
258
|
-
kwargs: Optional[Dict[str, Any]] =
|
|
263
|
+
kwargs: Optional[Dict[str, Any]] = None,
|
|
259
264
|
return_unused_kwargs: bool = False,
|
|
260
265
|
) -> Union["RBLNModelConfig", Tuple["RBLNModelConfig", Dict[str, Any]]]:
|
|
261
266
|
"""
|
|
@@ -269,6 +274,8 @@ class RBLNAutoConfig:
|
|
|
269
274
|
Returns:
|
|
270
275
|
RBLNModelConfig: The loaded RBLNModelConfig.
|
|
271
276
|
"""
|
|
277
|
+
if kwargs is None:
|
|
278
|
+
kwargs = {}
|
|
272
279
|
cls, config_file = load_config(path)
|
|
273
280
|
|
|
274
281
|
rbln_keys = [key for key in kwargs.keys() if key.startswith("rbln_")]
|
|
@@ -517,8 +524,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
517
524
|
non_save_attributes = [
|
|
518
525
|
"_frozen",
|
|
519
526
|
"_runtime_options",
|
|
520
|
-
"torch_dtype",
|
|
521
527
|
"npu",
|
|
528
|
+
"dtype",
|
|
522
529
|
"tensor_parallel_size",
|
|
523
530
|
"create_runtimes",
|
|
524
531
|
"device",
|
|
@@ -528,6 +535,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
528
535
|
]
|
|
529
536
|
submodules: List[str] = []
|
|
530
537
|
subclass_non_save_attributes = []
|
|
538
|
+
_allow_no_compile_cfgs = False
|
|
531
539
|
|
|
532
540
|
def initialize_submodule_config(
|
|
533
541
|
self,
|
|
@@ -642,6 +650,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
642
650
|
|
|
643
651
|
super().__setattr__(key, value)
|
|
644
652
|
|
|
653
|
+
@deprecate_kwarg(
|
|
654
|
+
old_name="_torch_dtype",
|
|
655
|
+
new_name="dtype",
|
|
656
|
+
version="0.12.0",
|
|
657
|
+
deprecated_type=torch.dtype,
|
|
658
|
+
value_replacer=RBLNCompileConfig.normalize_dtype,
|
|
659
|
+
raise_if_greater_or_equal_version=False,
|
|
660
|
+
)
|
|
645
661
|
def __init__(
|
|
646
662
|
self,
|
|
647
663
|
cls_name: Optional[str] = None,
|
|
@@ -653,8 +669,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
653
669
|
tensor_parallel_size: Optional[int] = None,
|
|
654
670
|
timeout: Optional[int] = None,
|
|
655
671
|
optimum_rbln_version: Optional[str] = None,
|
|
656
|
-
|
|
657
|
-
_compile_cfgs: List[RBLNCompileConfig] =
|
|
672
|
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
|
673
|
+
_compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
|
|
658
674
|
*,
|
|
659
675
|
optimize_host_memory: Optional[bool] = None,
|
|
660
676
|
**kwargs: Any,
|
|
@@ -672,7 +688,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
672
688
|
tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
|
|
673
689
|
timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
|
|
674
690
|
optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
|
|
675
|
-
|
|
691
|
+
dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
|
|
676
692
|
_compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
|
|
677
693
|
kwargs: Additional keyword arguments.
|
|
678
694
|
|
|
@@ -702,12 +718,15 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
702
718
|
self.npu = npu
|
|
703
719
|
self.tensor_parallel_size = tensor_parallel_size
|
|
704
720
|
|
|
705
|
-
|
|
721
|
+
if dtype is not None and isinstance(dtype, torch.dtype):
|
|
722
|
+
dtype = RBLNCompileConfig.normalize_dtype(dtype)
|
|
723
|
+
self._dtype = dtype or "float32"
|
|
706
724
|
self.optimum_rbln_version = optimum_rbln_version
|
|
707
725
|
if self.optimum_rbln_version is None:
|
|
708
726
|
self.optimum_rbln_version = __version__
|
|
709
727
|
|
|
710
|
-
|
|
728
|
+
compile_cfgs = _compile_cfgs if _compile_cfgs is not None else []
|
|
729
|
+
self._compile_cfgs: List[RBLNCompileConfig] = compile_cfgs
|
|
711
730
|
|
|
712
731
|
if not isinstance(self._compile_cfgs, list):
|
|
713
732
|
raise ValueError("`compile_cfgs` must be a list of `RBLNCompileConfig`.")
|
|
@@ -734,14 +753,24 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
734
753
|
|
|
735
754
|
@property
|
|
736
755
|
def torch_dtype(self):
|
|
737
|
-
|
|
756
|
+
logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
|
|
757
|
+
return self.dtype
|
|
738
758
|
|
|
739
759
|
@torch_dtype.setter
|
|
740
760
|
def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
|
|
741
|
-
|
|
742
|
-
|
|
761
|
+
logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
|
|
762
|
+
self.dtype = torch_dtype
|
|
763
|
+
|
|
764
|
+
@property
|
|
765
|
+
def dtype(self):
|
|
766
|
+
return getattr(torch, self._dtype)
|
|
767
|
+
|
|
768
|
+
@dtype.setter
|
|
769
|
+
def dtype(self, dtype: Union[str, torch.dtype]):
|
|
770
|
+
if isinstance(dtype, torch.dtype):
|
|
771
|
+
dtype = RBLNCompileConfig.normalize_dtype(dtype)
|
|
743
772
|
|
|
744
|
-
self.
|
|
773
|
+
self._dtype = dtype
|
|
745
774
|
|
|
746
775
|
@property
|
|
747
776
|
def rbln_model_cls_name(self) -> str:
|
|
@@ -765,10 +794,15 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
765
794
|
if isinstance(value, RBLNSerializableConfigProtocol):
|
|
766
795
|
# Convert nested RBLNModelConfig to its serializable form
|
|
767
796
|
serializable_map[key] = value._prepare_for_serialization()
|
|
797
|
+
elif key == "_dtype":
|
|
798
|
+
serializable_map["dtype"] = value
|
|
799
|
+
elif isinstance(value, list) and all(isinstance(item, RBLNSerializableConfigProtocol) for item in value):
|
|
800
|
+
serializable_map[key] = [item._prepare_for_serialization() for item in value]
|
|
768
801
|
elif key == "_compile_cfgs":
|
|
769
802
|
serializable_map[key] = [cfg.asdict() for cfg in value]
|
|
770
803
|
else:
|
|
771
804
|
serializable_map[key] = value
|
|
805
|
+
|
|
772
806
|
return serializable_map
|
|
773
807
|
|
|
774
808
|
def __repr__(self):
|
|
@@ -808,25 +842,20 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
808
842
|
or len(self._compile_cfgs) == 0
|
|
809
843
|
or not all(isinstance(cfg, RBLNCompileConfig) for cfg in self._compile_cfgs)
|
|
810
844
|
):
|
|
811
|
-
|
|
845
|
+
if not self._allow_no_compile_cfgs:
|
|
846
|
+
raise RuntimeError("`compile_cfgs` must contain at least one `RBLNCompileConfig` before freezing.")
|
|
812
847
|
|
|
813
848
|
for submodule_name in self.submodules:
|
|
814
849
|
submodule_config = getattr(self, submodule_name, None)
|
|
815
850
|
if not isinstance(submodule_config, RBLNModelConfig):
|
|
816
851
|
raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
|
|
817
852
|
|
|
818
|
-
if not submodule_config.is_frozen():
|
|
819
|
-
raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
|
|
820
|
-
|
|
821
853
|
self._frozen = True
|
|
822
854
|
|
|
823
855
|
def is_frozen(self):
|
|
824
856
|
return self._frozen
|
|
825
857
|
|
|
826
858
|
def save(self, path: str):
|
|
827
|
-
if not self._frozen:
|
|
828
|
-
raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
|
|
829
|
-
|
|
830
859
|
# save as json file without runtime attributes
|
|
831
860
|
path = Path(path)
|
|
832
861
|
if path.is_dir():
|
|
@@ -57,6 +57,9 @@ _import_structure = {
|
|
|
57
57
|
"RBLNSD3Transformer2DModelConfig",
|
|
58
58
|
"RBLNUNet2DConditionModelConfig",
|
|
59
59
|
"RBLNVQModelConfig",
|
|
60
|
+
"RBLNUNetSpatioTemporalConditionModelConfig",
|
|
61
|
+
"RBLNStableVideoDiffusionPipelineConfig",
|
|
62
|
+
"RBLNAutoencoderKLTemporalDecoderConfig",
|
|
60
63
|
],
|
|
61
64
|
"pipelines": [
|
|
62
65
|
"RBLNAutoPipelineForImage2Image",
|
|
@@ -86,14 +89,17 @@ _import_structure = {
|
|
|
86
89
|
"RBLNStableDiffusion3Pipeline",
|
|
87
90
|
"RBLNStableDiffusion3Img2ImgPipeline",
|
|
88
91
|
"RBLNStableDiffusion3InpaintPipeline",
|
|
92
|
+
"RBLNStableVideoDiffusionPipeline",
|
|
89
93
|
],
|
|
90
94
|
"models": [
|
|
91
95
|
"RBLNAutoencoderKL",
|
|
92
96
|
"RBLNAutoencoderKLCosmos",
|
|
93
97
|
"RBLNUNet2DConditionModel",
|
|
98
|
+
"RBLNUNetSpatioTemporalConditionModel",
|
|
94
99
|
"RBLNControlNetModel",
|
|
95
100
|
"RBLNCosmosTransformer3DModel",
|
|
96
101
|
"RBLNSD3Transformer2DModel",
|
|
102
|
+
"RBLNAutoencoderKLTemporalDecoder",
|
|
97
103
|
"RBLNPriorTransformer",
|
|
98
104
|
"RBLNVQModel",
|
|
99
105
|
],
|
|
@@ -106,6 +112,7 @@ if TYPE_CHECKING:
|
|
|
106
112
|
from .configurations import (
|
|
107
113
|
RBLNAutoencoderKLConfig,
|
|
108
114
|
RBLNAutoencoderKLCosmosConfig,
|
|
115
|
+
RBLNAutoencoderKLTemporalDecoderConfig,
|
|
109
116
|
RBLNControlNetModelConfig,
|
|
110
117
|
RBLNCosmosTextToWorldPipelineConfig,
|
|
111
118
|
RBLNCosmosTransformer3DModelConfig,
|
|
@@ -132,18 +139,22 @@ if TYPE_CHECKING:
|
|
|
132
139
|
RBLNStableDiffusionXLImg2ImgPipelineConfig,
|
|
133
140
|
RBLNStableDiffusionXLInpaintPipelineConfig,
|
|
134
141
|
RBLNStableDiffusionXLPipelineConfig,
|
|
142
|
+
RBLNStableVideoDiffusionPipelineConfig,
|
|
135
143
|
RBLNUNet2DConditionModelConfig,
|
|
144
|
+
RBLNUNetSpatioTemporalConditionModelConfig,
|
|
136
145
|
RBLNVQModelConfig,
|
|
137
146
|
)
|
|
138
147
|
from .modeling_diffusers import RBLNDiffusionMixin
|
|
139
148
|
from .models import (
|
|
140
149
|
RBLNAutoencoderKL,
|
|
141
150
|
RBLNAutoencoderKLCosmos,
|
|
151
|
+
RBLNAutoencoderKLTemporalDecoder,
|
|
142
152
|
RBLNControlNetModel,
|
|
143
153
|
RBLNCosmosTransformer3DModel,
|
|
144
154
|
RBLNPriorTransformer,
|
|
145
155
|
RBLNSD3Transformer2DModel,
|
|
146
156
|
RBLNUNet2DConditionModel,
|
|
157
|
+
RBLNUNetSpatioTemporalConditionModel,
|
|
147
158
|
RBLNVQModel,
|
|
148
159
|
)
|
|
149
160
|
from .pipelines import (
|
|
@@ -174,6 +185,7 @@ if TYPE_CHECKING:
|
|
|
174
185
|
RBLNStableDiffusionXLImg2ImgPipeline,
|
|
175
186
|
RBLNStableDiffusionXLInpaintPipeline,
|
|
176
187
|
RBLNStableDiffusionXLPipeline,
|
|
188
|
+
RBLNStableVideoDiffusionPipeline,
|
|
177
189
|
)
|
|
178
190
|
else:
|
|
179
191
|
import sys
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from .models import (
|
|
2
2
|
RBLNAutoencoderKLConfig,
|
|
3
3
|
RBLNAutoencoderKLCosmosConfig,
|
|
4
|
+
RBLNAutoencoderKLTemporalDecoderConfig,
|
|
4
5
|
RBLNControlNetModelConfig,
|
|
5
6
|
RBLNCosmosTransformer3DModelConfig,
|
|
6
7
|
RBLNPriorTransformerConfig,
|
|
7
8
|
RBLNSD3Transformer2DModelConfig,
|
|
8
9
|
RBLNUNet2DConditionModelConfig,
|
|
10
|
+
RBLNUNetSpatioTemporalConditionModelConfig,
|
|
9
11
|
RBLNVQModelConfig,
|
|
10
12
|
)
|
|
11
13
|
from .pipelines import (
|
|
@@ -31,4 +33,5 @@ from .pipelines import (
|
|
|
31
33
|
RBLNStableDiffusionXLImg2ImgPipelineConfig,
|
|
32
34
|
RBLNStableDiffusionXLInpaintPipelineConfig,
|
|
33
35
|
RBLNStableDiffusionXLPipelineConfig,
|
|
36
|
+
RBLNStableVideoDiffusionPipelineConfig,
|
|
34
37
|
)
|
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
from .configuration_autoencoder_kl import RBLNAutoencoderKLConfig
|
|
2
2
|
from .configuration_autoencoder_kl_cosmos import RBLNAutoencoderKLCosmosConfig
|
|
3
|
+
from .configuration_autoencoder_kl_temporal_decoder import RBLNAutoencoderKLTemporalDecoderConfig
|
|
3
4
|
from .configuration_controlnet import RBLNControlNetModelConfig
|
|
4
5
|
from .configuration_prior_transformer import RBLNPriorTransformerConfig
|
|
5
6
|
from .configuration_transformer_cosmos import RBLNCosmosTransformer3DModelConfig
|
|
6
7
|
from .configuration_transformer_sd3 import RBLNSD3Transformer2DModelConfig
|
|
7
8
|
from .configuration_unet_2d_condition import RBLNUNet2DConditionModelConfig
|
|
9
|
+
from .configuration_unet_spatio_temporal_condition import RBLNUNetSpatioTemporalConditionModelConfig
|
|
8
10
|
from .configuration_vq_model import RBLNVQModelConfig
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNAutoencoderKLTemporalDecoderConfig(RBLNModelConfig):
    """Compile-time configuration for an RBLN AutoencoderKLTemporalDecoder (video VAE).

    Holds the static shapes (batch size, spatial sample size, frame count) that the
    RBLN compiler needs, plus decode chunking and the pixel/latent scale factor.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        sample_size: Optional[Tuple[int, int]] = None,
        uses_encoder: Optional[bool] = None,
        num_frames: Optional[int] = None,
        decode_chunk_size: Optional[int] = None,
        vae_scale_factor: Optional[float] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the input/output images.
                If an integer is provided, it's used for both height and width.
            uses_encoder (Optional[bool]): Whether to include the encoder part of the VAE in the model.
                When False, only the decoder is used (for latent-to-image conversion).
            num_frames (Optional[int]): The number of frames in the generated video.
            decode_chunk_size (Optional[int]): The number of frames to decode at once during VAE decoding.
                Useful for managing memory usage during video generation.
            vae_scale_factor (Optional[float]): The scaling factor between pixel space and latent space.
                Determines how much smaller the latent representations are compared to the original images.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        # `or 1` maps both None and 0 to 1, so only negative values can reach the
        # check below; `<= 0` keeps the condition aligned with the error message.
        self.batch_size = batch_size or 1
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.uses_encoder = uses_encoder
        self.num_frames = num_frames
        self.decode_chunk_size = decode_chunk_size
        self.vae_scale_factor = vae_scale_factor
        self.sample_size = sample_size
        if isinstance(sample_size, int):
            # Normalize a scalar size to a square (height, width) pair.
            self.sample_size = (sample_size, sample_size)

    @property
    def image_size(self):
        """Pixel-space (height, width) — alias for ``sample_size``."""
        return self.sample_size

    @property
    def latent_sample_size(self):
        """Latent-space (height, width), i.e. image size divided by the VAE scale factor.

        NOTE(review): if ``vae_scale_factor`` is a float, ``//`` yields floats here —
        callers appear to expect integer dims; confirm upstream always passes an int-like factor.
        """
        return (self.image_size[0] // self.vae_scale_factor, self.image_size[1] // self.vae_scale_factor)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNUNetSpatioTemporalConditionModelConfig(RBLNModelConfig):
    """Compile-time configuration for an RBLN UNetSpatioTemporalConditionModel.

    Records the static inference shapes (batch size, spatial sample size, frame
    count, input feature width) and whether the batch size was set explicitly by
    the user (so pipelines can distinguish an explicit 1 from the default 1).
    """

    # Internal bookkeeping flag; excluded from serialized config output.
    subclass_non_save_attributes = ["_batch_size_is_specified"]

    def __init__(
        self,
        batch_size: Optional[int] = None,
        sample_size: Optional[Tuple[int, int]] = None,
        in_features: Optional[int] = None,
        num_frames: Optional[int] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            sample_size (Optional[Tuple[int, int]]): The spatial dimensions (height, width) of the generated samples.
                If an integer is provided, it's used for both height and width.
            in_features (Optional[int]): Number of input features for the model.
            num_frames (Optional[int]): The number of frames in the generated video.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        # Remember whether the caller supplied batch_size before defaulting it.
        self._batch_size_is_specified = batch_size is not None

        # `or 1` maps both None and 0 to 1, so only negative values can reach the
        # check below; `<= 0` keeps the condition aligned with the error message.
        self.batch_size = batch_size or 1
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.in_features = in_features
        self.num_frames = num_frames

        self.sample_size = sample_size
        if isinstance(sample_size, int):
            # Normalize a scalar size to a square (height, width) pair.
            self.sample_size = (sample_size, sample_size)

    @property
    def batch_size_is_specified(self):
        """True when the user passed an explicit ``batch_size`` (vs. the default of 1)."""
        return self._batch_size_is_specified
|