optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/attn.py CHANGED
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_attn_decode_kv_fp8.register_fake
+def paged_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_attn_prefill_kv_fp8.register_fake
+def paged_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_causal_attn_decode",
     mutates_args=(["kcache", "vcache"]),
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_causal_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_causal_attn_decode_kv_fp8.register_fake
+def paged_causal_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_causal_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_causal_attn_prefill_kv_fp8.register_fake
+def paged_causal_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_add_softmax_attn_decode",
     mutates_args=(["kcache", "vcache"]),
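The new *_kv_fp8 ops mirror the existing paged-attention schemas and add k_scale/v_scale arguments, which by their names carry the quantization scales for an fp8 KV cache. Each Python body is only a shape-level stub (torch.empty_like(q)); registering it with torch.library.custom_op fixes the operator schema that the RBLN compiler later lowers to a real kernel. Below is a minimal sketch of calling one of the new ops through the torch.ops namespace; every shape is an illustrative assumption, and it assumes that importing optimum.rbln.ops.attn is enough to trigger registration.

import torch
import optimum.rbln.ops.attn  # noqa: F401  # registers the rbln_custom_ops::* schemas on import (assumed)

# Placeholder tensors only; the real layouts are dictated by the RBLN runtime, not by this stub.
q = torch.randn(1, 8, 1, 64)                         # single decode-step query (assumed layout)
k = torch.randn(1, 8, 1, 64)
v = torch.randn(1, 8, 1, 64)
mask = torch.zeros(1, 1, 1, 2048)
kcache = torch.zeros(16, 8, 128, 64)                 # 16 cache blocks of size 128 (assumed)
vcache = torch.zeros(16, 8, 128, 64)
seq = torch.tensor([5], dtype=torch.int32)           # current sequence length
scale = torch.tensor(0.125)                          # 1/sqrt(head_dim)
block_table = torch.zeros(1, 16, dtype=torch.int16)  # block ids per sequence (assumed)
k_scale = torch.tensor(1.0)                          # fp8 dequantization scales (assumed semantics)
v_scale = torch.tensor(1.0)

out = torch.ops.rbln_custom_ops.paged_attn_decode_kv_fp8(
    q, k, v, mask, kcache, vcache, seq, scale, block_table, 128, k_scale, v_scale
)
print(out.shape)  # same shape as q: the eager body is just torch.empty_like(q)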
optimum/rbln/ops/flash_attn.py CHANGED
@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode_kv_fp8.register_fake
+def paged_flash_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill_kv_fp8.register_fake
+def paged_flash_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_causal_attn_decode",
     mutates_args=(["kcache", "vcache"]),
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode_kv_fp8.register_fake
+def paged_flash_causal_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_causal_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
     mask: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill_kv_fp8.register_fake
+def paged_flash_causal_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
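Every op above is registered twice: once as the real schema via torch.library.custom_op and once as a fake kernel via register_fake. The fake kernel is what FakeTensor-based tracing (torch.export, torch.compile) runs to propagate shapes and dtypes without executing anything, which is how these stubs end up in the traced graph handed to the RBLN compiler. A self-contained toy re-creation of the same pattern follows; "demo_ops::cache_write" is a hypothetical op, not part of optimum-rbln.

import torch
from torch import Tensor
from torch._subclasses.fake_tensor import FakeTensorMode


@torch.library.custom_op("demo_ops::cache_write", mutates_args=["cache"])
def cache_write(cache: Tensor, state: Tensor) -> Tensor:
    # A real implementation would write `state` into `cache` on the device.
    return torch.empty_like(state)


@cache_write.register_fake
def cache_write_fake(cache: Tensor, state: Tensor) -> Tensor:
    # Metadata only: tells the tracer the output has `state`'s shape and dtype.
    return torch.empty_like(state)


# Under FakeTensorMode (what torch.export / torch.compile use while tracing),
# calling the op dispatches to the fake kernel instead of the real body.
with FakeTensorMode():
    cache = torch.empty(4, 128)
    state = torch.empty(1, 128)
    out = torch.ops.demo_ops.cache_write(cache, state)
    print(out.shape)  # torch.Size([1, 128])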
optimum/rbln/ops/kv_cache_update.py CHANGED
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
     # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
     # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
     return torch.empty_like(cache)
+
+
+@rbln_cache_update.register_fake
+def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
+    return torch.empty_like(cache)
optimum/rbln/ops/linear.py CHANGED
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
     output_shape = list(input.shape[:-1])
     output_shape += [weight.shape[0]]
     return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
+
+
+@linear.register_fake
+def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+    output_shape = list(input.shape[:-1])
+    output_shape += [weight.shape[0]]
+    return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
optimum/rbln/transformers/__init__.py CHANGED
@@ -34,6 +34,8 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
         "RBLNBartForConditionalGeneration",
         "RBLNBartForConditionalGenerationConfig",
         "RBLNBartModel",
@@ -52,6 +54,8 @@ _import_structure = {
         "RBLNBlip2VisionModelConfig",
         "RBLNColPaliForRetrieval",
         "RBLNColPaliForRetrievalConfig",
+        "RBLNColQwen2ForRetrieval",
+        "RBLNColQwen2ForRetrievalConfig",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
@@ -62,12 +66,18 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjectionConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNDecoderOnlyModelConfig",
+        "RBLNDecoderOnlyModel",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
+        "RBLNDepthAnythingForDepthEstimation",
+        "RBLNDepthAnythingForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
+        "RBLNGemmaModel",
+        "RBLNGemmaModelConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
@@ -76,26 +86,60 @@ _import_structure = {
         "RBLNGemmaForCausalLMConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
+        "RBLNGPT2Model",
+        "RBLNGPT2ModelConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3VisionTransformerConfig",
         "RBLNLlamaForCausalLM",
         "RBLNLlamaForCausalLMConfig",
+        "RBLNLlavaForConditionalGeneration",
+        "RBLNLlavaForConditionalGenerationConfig",
+        "RBLNLlamaModel",
+        "RBLNLlamaModelConfig",
+        "RBLNOPTForCausalLM",
+        "RBLNOPTForCausalLMConfig",
+        "RBLNPegasusForConditionalGeneration",
+        "RBLNPegasusForConditionalGenerationConfig",
+        "RBLNPegasusModel",
+        "RBLNPegasusModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
+        "RBLNLoRAAdapterConfig",
+        "RBLNLoRAConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
         "RBLNMistralForCausalLMConfig",
+        "RBLNMistralModel",
+        "RBLNMistralModelConfig",
         "RBLNOPTForCausalLM",
         "RBLNOPTForCausalLMConfig",
+        "RBLNOPTModel",
+        "RBLNOPTModelConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
+        "RBLNPixtralVisionModelConfig",
+        "RBLNPixtralVisionModel",
+        "RBLNPhiModel",
+        "RBLNPhiModelConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2VisionTransformerPretrainedModel",
+        "RBLNQwen2VisionTransformerPretrainedModelConfig",
+        "RBLNQwen2VLForConditionalGeneration",
+        "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2Model",
+        "RBLNQwen2ModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
         "RBLNQwen3ForCausalLM",
@@ -110,6 +154,8 @@ _import_structure = {
         "RBLNRobertaForSequenceClassificationConfig",
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
+        "RBLNSwinBackbone",
+        "RBLNSwinBackboneConfig",
         "RBLNT5EncoderModel",
         "RBLNT5EncoderModelConfig",
         "RBLNT5ForConditionalGeneration",
@@ -145,7 +191,9 @@ if TYPE_CHECKING:
         RBLNAutoModelForSeq2SeqLM,
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
         RBLNBartForConditionalGeneration,
         RBLNBartForConditionalGenerationConfig,
         RBLNBartModel,
@@ -170,8 +218,16 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelConfig,
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
+        RBLNColPaliForRetrieval,
+        RBLNColPaliForRetrievalConfig,
+        RBLNColQwen2ForRetrieval,
+        RBLNColQwen2ForRetrievalConfig,
+        RBLNDecoderOnlyModel,
+        RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNDepthAnythingForDepthEstimation,
+        RBLNDepthAnythingForDepthEstimationConfig,
         RBLNDistilBertForQuestionAnswering,
         RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
@@ -184,30 +240,64 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
+        RBLNGemmaModel,
+        RBLNGemmaModelConfig,
         RBLNGPT2LMHeadModel,
         RBLNGPT2LMHeadModelConfig,
+        RBLNGPT2Model,
+        RBLNGPT2ModelConfig,
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
         RBLNIdefics3VisionTransformerConfig,
         RBLNLlamaForCausalLM,
         RBLNLlamaForCausalLMConfig,
+        RBLNLlamaModel,
+        RBLNLlamaModelConfig,
+        RBLNLlavaForConditionalGeneration,
+        RBLNLlavaForConditionalGenerationConfig,
         RBLNLlavaNextForConditionalGeneration,
         RBLNLlavaNextForConditionalGenerationConfig,
+        RBLNLoRAAdapterConfig,
+        RBLNLoRAConfig,
         RBLNMidmLMHeadModel,
         RBLNMidmLMHeadModelConfig,
         RBLNMistralForCausalLM,
         RBLNMistralForCausalLMConfig,
+        RBLNMistralModel,
+        RBLNMistralModelConfig,
         RBLNOPTForCausalLM,
         RBLNOPTForCausalLMConfig,
+        RBLNOPTModel,
+        RBLNOPTModelConfig,
+        RBLNPegasusForConditionalGeneration,
+        RBLNPegasusForConditionalGenerationConfig,
+        RBLNPegasusModel,
+        RBLNPegasusModelConfig,
         RBLNPhiForCausalLM,
         RBLNPhiForCausalLMConfig,
+        RBLNPhiModel,
+        RBLNPhiModelConfig,
+        RBLNPixtralVisionModel,
+        RBLNPixtralVisionModelConfig,
         RBLNQwen2_5_VisionTransformerPretrainedModel,
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
+        RBLNQwen2Model,
+        RBLNQwen2ModelConfig,
+        RBLNQwen2VisionTransformerPretrainedModel,
+        RBLNQwen2VisionTransformerPretrainedModelConfig,
+        RBLNQwen2VLForConditionalGeneration,
+        RBLNQwen2VLForConditionalGenerationConfig,
         RBLNQwen3ForCausalLM,
         RBLNQwen3ForCausalLMConfig,
         RBLNQwen3Model,
@@ -220,6 +310,8 @@ if TYPE_CHECKING:
         RBLNRobertaForSequenceClassificationConfig,
         RBLNSiglipVisionModel,
         RBLNSiglipVisionModelConfig,
+        RBLNSwinBackbone,
+        RBLNSwinBackboneConfig,
         RBLNT5EncoderModel,
         RBLNT5EncoderModelConfig,
         RBLNT5ForConditionalGeneration,
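Both halves of this file stay in sync: every name added to _import_structure also appears under if TYPE_CHECKING:, so type checkers resolve the symbols eagerly while runtime imports remain lazy. In practice the new classes (RBLNLlamaModel, RBLNQwen2VLForConditionalGeneration, RBLNGroundingDinoForObjectDetection, and so on) become importable like the existing ones. A hedged sketch of compiling one of the newly exported backbone classes follows, using the established from_pretrained(..., export=True) convention; the checkpoint id and rbln_* options are illustrative, and it assumes the top-level optimum.rbln package re-exports the class.

from optimum.rbln import RBLNLlamaModel  # new in 0.9.x; assumed top-level re-export

# Trace and compile a headless Llama backbone for RBLN NPUs, then reload the
# compiled artifacts without re-exporting. Option names follow the existing
# rbln_* convention; the exact set accepted in 0.9.3 may differ.
model = RBLNLlamaModel.from_pretrained(
    "meta-llama/Llama-3.1-8B",   # illustrative checkpoint id
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=8192,
)
model.save_pretrained("llama-3.1-8b-rbln")

# Later: load the precompiled model directly.
model = RBLNLlamaModel.from_pretrained("llama-3.1-8b-rbln", export=False)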
optimum/rbln/transformers/configuration_generic.py CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 from ..configuration_utils import RBLNModelConfig
 
@@ -25,7 +25,8 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
-        **kwargs: Dict[str, Any],
+        model_input_shapes: Optional[List[Tuple[int, int]]] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -33,7 +34,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             model_input_names (Optional[List[str]]): Names of the input tensors for the model.
                 Defaults to class-specific rbln_model_input_names if not provided.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -45,6 +46,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
         self.model_input_names = model_input_names or self.rbln_model_input_names
+        self.model_input_shapes = model_input_shapes
 
 
 class RBLNImageModelConfig(RBLNModelConfig):
@@ -52,14 +54,14 @@ class RBLNImageModelConfig(RBLNModelConfig):
         self,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
         batch_size: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
             image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
                 Can be an integer for square images or a tuple (height, width).
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -116,30 +118,3 @@ class RBLNModelForImageClassificationConfig(RBLNImageModelConfig):
 
 class RBLNModelForDepthEstimationConfig(RBLNImageModelConfig):
     pass
-
-
-class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
-    def __init__(
-        self,
-        batch_size: Optional[int] = None,
-        max_length: Optional[int] = None,
-        num_mel_bins: Optional[int] = None,
-        **kwargs: Dict[str, Any],
-    ):
-        """
-        Args:
-            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
-            max_length (Optional[int]): Maximum length of the audio input in time dimension.
-            num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
-
-        Raises:
-            ValueError: If batch_size is not a positive integer.
-        """
-        super().__init__(**kwargs)
-        self.batch_size = batch_size or 1
-        if not isinstance(self.batch_size, int) or self.batch_size < 0:
-            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
-
-        self.max_length = max_length
-        self.num_mel_bins = num_mel_bins
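Two things change here: the keyword-argument annotations are corrected from Dict[str, Any] to Any (the annotation on **kwargs describes each value, not the whole mapping), and RBLNTransformerEncoderConfig gains a model_input_shapes argument next to model_input_names; RBLNModelForAudioClassificationConfig is dropped from this module, with the audio-spectrogram-transformer configuration updated separately in this release (see the file list above). A hedged sketch of the new argument follows; the base class is normally used through model-specific subclasses, and the semantics of model_input_shapes are an assumption based on its type hint.

from optimum.rbln.transformers.configuration_generic import RBLNTransformerEncoderConfig

# Construct the base encoder config directly, only to show the new field.
cfg = RBLNTransformerEncoderConfig(
    max_seq_len=512,
    batch_size=2,
    model_input_names=["input_ids", "attention_mask"],
    # New in 0.9.3: one (batch, seq) shape per named input, presumably
    # overriding shapes otherwise derived from batch_size/max_seq_len.
    model_input_shapes=[(2, 512), (2, 512)],
)
print(cfg.model_input_names)   # ['input_ids', 'attention_mask']
print(cfg.model_input_shapes)  # [(2, 512), (2, 512)]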