optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic; consult the package registry's advisory page for details.

Files changed (90)
  1. optimum/rbln/__init__.py +8 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +6 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +10 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -257
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  81. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  82. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  83. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  84. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  85. optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
  86. optimum/rbln/utils/runtime_utils.py +3 -3
  87. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
  88. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
  89. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
  90. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling.py CHANGED
@@ -78,7 +78,7 @@ class RBLNModel(RBLNBaseModel):
78
78
  rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
79
79
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
80
80
  subfolder: str = "",
81
- **kwargs: Dict[str, Any],
81
+ **kwargs: Any,
82
82
  ) -> "RBLNModel":
83
83
  """
84
84
  Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
@@ -241,7 +241,7 @@ class RBLNModel(RBLNBaseModel):
241
241
  for compiled_model in compiled_models
242
242
  ]
243
243
 
244
- def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
244
+ def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
245
245
  """
246
246
  Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
247
247
 
@@ -348,7 +348,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
348
348
  model_id: Union[str, Path],
349
349
  export: bool = False,
350
350
  rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
351
- **kwargs: Dict[str, Any],
351
+ **kwargs: Any,
352
352
  ) -> "RBLNBaseModel":
353
353
  """
354
354
  The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
@@ -523,10 +523,18 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
523
523
  # First copy everything to a temporary directory
524
524
  shutil.copytree(real_save_dir, tmp_dir)
525
525
 
526
- # If everything succeeded, atomically replace the target directory
526
+ # If everything succeeded, move files to target directory
527
527
  if os.path.exists(save_directory_path):
528
- shutil.rmtree(save_directory_path)
529
- os.rename(tmp_dir, save_directory_path)
528
+ # Move files from tmp_dir to existing directory (overwrite existing files)
529
+ for item in os.listdir(tmp_dir):
530
+ src_path = os.path.join(tmp_dir, item)
531
+ dst_path = os.path.join(save_directory_path, item)
532
+ shutil.move(src_path, dst_path)
533
+ # Clean up empty tmp_dir
534
+ os.rmdir(tmp_dir)
535
+ else:
536
+ # If target doesn't exist, just rename tmp_dir to target
537
+ os.rename(tmp_dir, save_directory_path)
530
538
 
531
539
  except Exception as e:
532
540
  # Clean up the temporary directory if anything fails
optimum/rbln/ops/attn.py CHANGED
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
53
53
  return torch.empty_like(q)
54
54
 
55
55
 
56
+ @torch.library.custom_op(
57
+ "rbln_custom_ops::paged_attn_decode_kv_fp8",
58
+ mutates_args=(["kcache", "vcache"]),
59
+ )
60
+ def paged_attn_decode_kv_fp8(
61
+ q: Tensor,
62
+ k: Tensor,
63
+ v: Tensor,
64
+ mask: Tensor,
65
+ kcache: Tensor,
66
+ vcache: Tensor,
67
+ seq: Tensor,
68
+ scale: Tensor,
69
+ block_table: Tensor,
70
+ block_size: int,
71
+ k_scale: Tensor,
72
+ v_scale: Tensor,
73
+ ) -> Tensor:
74
+ return torch.empty_like(q)
75
+
76
+
77
+ @paged_attn_decode_kv_fp8.register_fake
78
+ def paged_attn_decode_kv_fp8_fake(
79
+ q: Tensor,
80
+ k: Tensor,
81
+ v: Tensor,
82
+ mask: Tensor,
83
+ kcache: Tensor,
84
+ vcache: Tensor,
85
+ seq: Tensor,
86
+ scale: Tensor,
87
+ block_table: Tensor,
88
+ block_size: int,
89
+ k_scale: Tensor,
90
+ v_scale: Tensor,
91
+ ) -> Tensor:
92
+ return torch.empty_like(q)
93
+
94
+
56
95
  @torch.library.custom_op(
57
96
  "rbln_custom_ops::paged_attn_prefill",
58
97
  mutates_args=(["kcache", "vcache"]),
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
112
151
  return torch.empty_like(q)
113
152
 
114
153
 
154
+ @torch.library.custom_op(
155
+ "rbln_custom_ops::paged_attn_prefill_kv_fp8",
156
+ mutates_args=(["kcache", "vcache"]),
157
+ )
158
+ def paged_attn_prefill_kv_fp8(
159
+ q: Tensor,
160
+ k: Tensor,
161
+ v: Tensor,
162
+ mask: Tensor,
163
+ kcache: Tensor,
164
+ vcache: Tensor,
165
+ seq: Tensor,
166
+ scale: Tensor,
167
+ block_table: Tensor,
168
+ block_size: int,
169
+ k_scale: Tensor,
170
+ v_scale: Tensor,
171
+ ) -> Tensor:
172
+ return torch.empty_like(q)
173
+
174
+
175
+ @paged_attn_prefill_kv_fp8.register_fake
176
+ def paged_attn_prefill_kv_fp8_fake(
177
+ q: Tensor,
178
+ k: Tensor,
179
+ v: Tensor,
180
+ mask: Tensor,
181
+ kcache: Tensor,
182
+ vcache: Tensor,
183
+ seq: Tensor,
184
+ scale: Tensor,
185
+ block_table: Tensor,
186
+ block_size: int,
187
+ k_scale: Tensor,
188
+ v_scale: Tensor,
189
+ ) -> Tensor:
190
+ return torch.empty_like(q)
191
+
192
+
115
193
  @torch.library.custom_op(
116
194
  "rbln_custom_ops::paged_causal_attn_decode",
117
195
  mutates_args=(["kcache", "vcache"]),
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
236
314
  return torch.empty_like(q)
237
315
 
238
316
 
317
+ @torch.library.custom_op(
318
+ "rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
319
+ mutates_args=(["kcache", "vcache"]),
320
+ )
321
+ def paged_causal_attn_decode_kv_fp8(
322
+ q: Tensor,
323
+ k: Tensor,
324
+ v: Tensor,
325
+ kcache: Tensor,
326
+ vcache: Tensor,
327
+ seq: Tensor,
328
+ scale: Tensor,
329
+ block_table: Tensor,
330
+ block_size: int,
331
+ k_scale: Tensor,
332
+ v_scale: Tensor,
333
+ mask: Optional[Tensor] = None,
334
+ ) -> Tensor:
335
+ return torch.empty_like(q)
336
+
337
+
338
+ @paged_causal_attn_decode_kv_fp8.register_fake
339
+ def paged_causal_attn_decode_kv_fp8_fake(
340
+ q: Tensor,
341
+ k: Tensor,
342
+ v: Tensor,
343
+ kcache: Tensor,
344
+ vcache: Tensor,
345
+ seq: Tensor,
346
+ scale: Tensor,
347
+ block_table: Tensor,
348
+ block_size: int,
349
+ k_scale: Tensor,
350
+ v_scale: Tensor,
351
+ mask: Optional[Tensor] = None,
352
+ ) -> Tensor:
353
+ return torch.empty_like(q)
354
+
355
+
356
+ @torch.library.custom_op(
357
+ "rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
358
+ mutates_args=(["kcache", "vcache"]),
359
+ )
360
+ def paged_causal_attn_prefill_kv_fp8(
361
+ q: Tensor,
362
+ k: Tensor,
363
+ v: Tensor,
364
+ kcache: Tensor,
365
+ vcache: Tensor,
366
+ seq: Tensor,
367
+ scale: Tensor,
368
+ block_table: Tensor,
369
+ block_size: int,
370
+ is_bidirectional: bool,
371
+ k_scale: Tensor,
372
+ v_scale: Tensor,
373
+ mask: Optional[Tensor] = None,
374
+ ) -> Tensor:
375
+ return torch.empty_like(q)
376
+
377
+
378
+ @paged_causal_attn_prefill_kv_fp8.register_fake
379
+ def paged_causal_attn_prefill_kv_fp8_fake(
380
+ q: Tensor,
381
+ k: Tensor,
382
+ v: Tensor,
383
+ kcache: Tensor,
384
+ vcache: Tensor,
385
+ seq: Tensor,
386
+ scale: Tensor,
387
+ block_table: Tensor,
388
+ block_size: int,
389
+ is_bidirectional: bool,
390
+ k_scale: Tensor,
391
+ v_scale: Tensor,
392
+ mask: Optional[Tensor] = None,
393
+ ) -> Tensor:
394
+ return torch.empty_like(q)
395
+
396
+
239
397
  @torch.library.custom_op(
240
398
  "rbln_custom_ops::paged_add_softmax_attn_decode",
241
399
  mutates_args=(["kcache", "vcache"]),
@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
59
59
  return torch.empty_like(q)
60
60
 
61
61
 
62
+ @torch.library.custom_op(
63
+ "rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
64
+ mutates_args=(["kcache", "vcache"]),
65
+ )
66
+ def paged_flash_attn_decode_kv_fp8(
67
+ q: Tensor,
68
+ k: Tensor,
69
+ v: Tensor,
70
+ mask: Tensor,
71
+ kcache: Tensor,
72
+ vcache: Tensor,
73
+ seq: Tensor,
74
+ scale: Tensor,
75
+ block_table: Tensor,
76
+ block_size: int,
77
+ partition: int,
78
+ k_scale: Tensor,
79
+ v_scale: Tensor,
80
+ ) -> Tensor:
81
+ return torch.empty_like(q)
82
+
83
+
84
+ @paged_flash_attn_decode_kv_fp8.register_fake
85
+ def paged_flash_attn_decode_kv_fp8_fake(
86
+ q: Tensor,
87
+ k: Tensor,
88
+ v: Tensor,
89
+ mask: Tensor,
90
+ kcache: Tensor,
91
+ vcache: Tensor,
92
+ seq: Tensor,
93
+ scale: Tensor,
94
+ block_table: Tensor,
95
+ block_size: int,
96
+ partition: int,
97
+ k_scale: Tensor,
98
+ v_scale: Tensor,
99
+ ) -> Tensor:
100
+ return torch.empty_like(q)
101
+
102
+
62
103
  @torch.library.custom_op(
63
104
  "rbln_custom_ops::paged_flash_attn_prefill",
64
105
  mutates_args=(["kcache", "vcache"]),
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
100
141
  return torch.empty_like(q)
101
142
 
102
143
 
144
+ @torch.library.custom_op(
145
+ "rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
146
+ mutates_args=(["kcache", "vcache"]),
147
+ )
148
+ def paged_flash_attn_prefill_kv_fp8(
149
+ q: Tensor,
150
+ k: Tensor,
151
+ v: Tensor,
152
+ mask: Tensor,
153
+ kcache: Tensor,
154
+ vcache: Tensor,
155
+ seq: Tensor,
156
+ scale: Tensor,
157
+ block_table: Tensor,
158
+ block_size: int,
159
+ partition: int,
160
+ k_scale: Tensor,
161
+ v_scale: Tensor,
162
+ ) -> Tensor:
163
+ return torch.empty_like(q)
164
+
165
+
166
+ @paged_flash_attn_prefill_kv_fp8.register_fake
167
+ def paged_flash_attn_prefill_kv_fp8_fake(
168
+ q: Tensor,
169
+ k: Tensor,
170
+ v: Tensor,
171
+ mask: Tensor,
172
+ kcache: Tensor,
173
+ vcache: Tensor,
174
+ seq: Tensor,
175
+ scale: Tensor,
176
+ block_table: Tensor,
177
+ block_size: int,
178
+ partition: int,
179
+ k_scale: Tensor,
180
+ v_scale: Tensor,
181
+ ) -> Tensor:
182
+ return torch.empty_like(q)
183
+
184
+
103
185
  @torch.library.custom_op(
104
186
  "rbln_custom_ops::paged_flash_causal_attn_decode",
105
187
  mutates_args=(["kcache", "vcache"]),
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
141
223
  return torch.empty_like(q)
142
224
 
143
225
 
226
+ @torch.library.custom_op(
227
+ "rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
228
+ mutates_args=(["kcache", "vcache"]),
229
+ )
230
+ def paged_flash_causal_attn_decode_kv_fp8(
231
+ q: Tensor,
232
+ k: Tensor,
233
+ v: Tensor,
234
+ kcache: Tensor,
235
+ vcache: Tensor,
236
+ seq: Tensor,
237
+ scale: Tensor,
238
+ block_table: Tensor,
239
+ block_size: int,
240
+ partition: int,
241
+ k_scale: Tensor,
242
+ v_scale: Tensor,
243
+ mask: Optional[Tensor] = None,
244
+ ) -> Tensor:
245
+ return torch.empty_like(q)
246
+
247
+
248
+ @paged_flash_causal_attn_decode_kv_fp8.register_fake
249
+ def paged_flash_causal_attn_decode_kv_fp8_fake(
250
+ q: Tensor,
251
+ k: Tensor,
252
+ v: Tensor,
253
+ kcache: Tensor,
254
+ vcache: Tensor,
255
+ seq: Tensor,
256
+ scale: Tensor,
257
+ block_table: Tensor,
258
+ block_size: int,
259
+ partition: int,
260
+ k_scale: Tensor,
261
+ v_scale: Tensor,
262
+ mask: Optional[Tensor] = None,
263
+ ) -> Tensor:
264
+ return torch.empty_like(q)
265
+
266
+
144
267
  @torch.library.custom_op(
145
268
  "rbln_custom_ops::paged_flash_causal_attn_prefill",
146
269
  mutates_args=(["kcache", "vcache"]),
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
182
305
  mask: Optional[Tensor] = None,
183
306
  ) -> Tensor:
184
307
  return torch.empty_like(q)
308
+
309
+
310
+ @torch.library.custom_op(
311
+ "rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
312
+ mutates_args=(["kcache", "vcache"]),
313
+ )
314
+ def paged_flash_causal_attn_prefill_kv_fp8(
315
+ q: Tensor,
316
+ k: Tensor,
317
+ v: Tensor,
318
+ kcache: Tensor,
319
+ vcache: Tensor,
320
+ seq: Tensor,
321
+ scale: Tensor,
322
+ block_table: Tensor,
323
+ block_size: int,
324
+ partition: int,
325
+ is_bidirectional: bool,
326
+ k_scale: Tensor,
327
+ v_scale: Tensor,
328
+ mask: Optional[Tensor] = None,
329
+ ) -> Tensor:
330
+ return torch.empty_like(q)
331
+
332
+
333
+ @paged_flash_causal_attn_prefill_kv_fp8.register_fake
334
+ def paged_flash_causal_attn_prefill_kv_fp8_fake(
335
+ q: Tensor,
336
+ k: Tensor,
337
+ v: Tensor,
338
+ kcache: Tensor,
339
+ vcache: Tensor,
340
+ seq: Tensor,
341
+ scale: Tensor,
342
+ block_table: Tensor,
343
+ block_size: int,
344
+ partition: int,
345
+ is_bidirectional: bool,
346
+ k_scale: Tensor,
347
+ v_scale: Tensor,
348
+ mask: Optional[Tensor] = None,
349
+ ) -> Tensor:
350
+ return torch.empty_like(q)
@@ -34,6 +34,7 @@ _import_structure = {
34
34
  "RBLNAutoModelForSequenceClassification",
35
35
  "RBLNAutoModelForSpeechSeq2Seq",
36
36
  "RBLNAutoModelForVision2Seq",
37
+ "RBLNAutoModelForTextEncoding",
37
38
  "RBLNBartForConditionalGeneration",
38
39
  "RBLNBartForConditionalGenerationConfig",
39
40
  "RBLNBartModel",
@@ -62,6 +63,8 @@ _import_structure = {
62
63
  "RBLNCLIPVisionModelWithProjectionConfig",
63
64
  "RBLNDecoderOnlyModelForCausalLM",
64
65
  "RBLNDecoderOnlyModelForCausalLMConfig",
66
+ "RBLNDecoderOnlyModelConfig",
67
+ "RBLNDecoderOnlyModel",
65
68
  "RBLNDistilBertForQuestionAnswering",
66
69
  "RBLNDistilBertForQuestionAnsweringConfig",
67
70
  "RBLNDPTForDepthEstimation",
@@ -169,6 +172,7 @@ if TYPE_CHECKING:
169
172
  RBLNAutoModelForSeq2SeqLM,
170
173
  RBLNAutoModelForSequenceClassification,
171
174
  RBLNAutoModelForSpeechSeq2Seq,
175
+ RBLNAutoModelForTextEncoding,
172
176
  RBLNAutoModelForVision2Seq,
173
177
  RBLNBartForConditionalGeneration,
174
178
  RBLNBartForConditionalGenerationConfig,
@@ -196,6 +200,8 @@ if TYPE_CHECKING:
196
200
  RBLNCLIPVisionModelWithProjectionConfig,
197
201
  RBLNColPaliForRetrieval,
198
202
  RBLNColPaliForRetrievalConfig,
203
+ RBLNDecoderOnlyModel,
204
+ RBLNDecoderOnlyModelConfig,
199
205
  RBLNDecoderOnlyModelForCausalLM,
200
206
  RBLNDecoderOnlyModelForCausalLMConfig,
201
207
  RBLNDistilBertForQuestionAnswering,
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, List, Optional, Tuple, Union
15
+ from typing import Any, List, Optional, Tuple, Union
16
16
 
17
17
  from ..configuration_utils import RBLNModelConfig
18
18
 
@@ -25,7 +25,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
25
25
  max_seq_len: Optional[int] = None,
26
26
  batch_size: Optional[int] = None,
27
27
  model_input_names: Optional[List[str]] = None,
28
- **kwargs: Dict[str, Any],
28
+ **kwargs: Any,
29
29
  ):
30
30
  """
31
31
  Args:
@@ -52,7 +52,7 @@ class RBLNImageModelConfig(RBLNModelConfig):
52
52
  self,
53
53
  image_size: Optional[Union[int, Tuple[int, int]]] = None,
54
54
  batch_size: Optional[int] = None,
55
- **kwargs: Dict[str, Any],
55
+ **kwargs: Any,
56
56
  ):
57
57
  """
58
58
  Args:
@@ -124,7 +124,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
124
124
  batch_size: Optional[int] = None,
125
125
  max_length: Optional[int] = None,
126
126
  num_mel_bins: Optional[int] = None,
127
- **kwargs: Dict[str, Any],
127
+ **kwargs: Any,
128
128
  ):
129
129
  """
130
130
  Args:
@@ -34,10 +34,7 @@ from transformers import (
34
34
  AutoModelForTextEncoding,
35
35
  PretrainedConfig,
36
36
  )
37
- from transformers.modeling_outputs import (
38
- BaseModelOutput,
39
- QuestionAnsweringModelOutput,
40
- )
37
+ from transformers.modeling_outputs import BaseModelOutput, QuestionAnsweringModelOutput
41
38
 
42
39
  from ..configuration_utils import RBLNCompileConfig
43
40
  from ..modeling import RBLNModel
@@ -0,0 +1,37 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from transformers.modeling_outputs import ModelOutput
20
+
21
+
22
+ @dataclass
23
+ class RBLNDecoderOnlyOutput(ModelOutput):
24
+ logits: torch.FloatTensor = None
25
+ generate_idx: torch.Tensor = None
26
+ padded_cache_lengths: int = None
27
+
28
+
29
+ @dataclass
30
+ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
31
+ attention_mask: Optional[torch.Tensor] = None
32
+
33
+
34
+ @dataclass
35
+ class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
36
+ last_hidden_states: torch.FloatTensor = None
37
+ params: Tuple[torch.FloatTensor] = None
@@ -36,6 +36,7 @@ _import_structure = {
36
36
  "RBLNAutoModelForSpeechSeq2Seq",
37
37
  "RBLNAutoModelForVision2Seq",
38
38
  "RBLNAutoModelForImageTextToText",
39
+ "RBLNAutoModelForTextEncoding",
39
40
  ],
40
41
  "bart": [
41
42
  "RBLNBartForConditionalGeneration",
@@ -84,6 +85,8 @@ _import_structure = {
84
85
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
85
86
  ],
86
87
  "decoderonly": [
88
+ "RBLNDecoderOnlyModelConfig",
89
+ "RBLNDecoderOnlyModel",
87
90
  "RBLNDecoderOnlyModelForCausalLM",
88
91
  "RBLNDecoderOnlyModelForCausalLMConfig",
89
92
  ],
@@ -160,10 +163,7 @@ _import_structure = {
160
163
  }
161
164
 
162
165
  if TYPE_CHECKING:
163
- from .audio_spectrogram_transformer import (
164
- RBLNASTForAudioClassification,
165
- RBLNASTForAudioClassificationConfig,
166
- )
166
+ from .audio_spectrogram_transformer import RBLNASTForAudioClassification, RBLNASTForAudioClassificationConfig
167
167
  from .auto import (
168
168
  RBLNAutoModel,
169
169
  RBLNAutoModelForAudioClassification,
@@ -177,6 +177,7 @@ if TYPE_CHECKING:
177
177
  RBLNAutoModelForSeq2SeqLM,
178
178
  RBLNAutoModelForSequenceClassification,
179
179
  RBLNAutoModelForSpeechSeq2Seq,
180
+ RBLNAutoModelForTextEncoding,
180
181
  RBLNAutoModelForVision2Seq,
181
182
  )
182
183
  from .bart import (
@@ -211,22 +212,15 @@ if TYPE_CHECKING:
211
212
  RBLNCLIPVisionModelWithProjection,
212
213
  RBLNCLIPVisionModelWithProjectionConfig,
213
214
  )
214
- from .colpali import (
215
- RBLNColPaliForRetrieval,
216
- RBLNColPaliForRetrievalConfig,
217
- )
215
+ from .colpali import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
218
216
  from .decoderonly import (
217
+ RBLNDecoderOnlyModel,
218
+ RBLNDecoderOnlyModelConfig,
219
219
  RBLNDecoderOnlyModelForCausalLM,
220
220
  RBLNDecoderOnlyModelForCausalLMConfig,
221
221
  )
222
- from .distilbert import (
223
- RBLNDistilBertForQuestionAnswering,
224
- RBLNDistilBertForQuestionAnsweringConfig,
225
- )
226
- from .dpt import (
227
- RBLNDPTForDepthEstimation,
228
- RBLNDPTForDepthEstimationConfig,
229
- )
222
+ from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
223
+ from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
230
224
  from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
231
225
  from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
232
226
  from .gemma3 import (
@@ -25,5 +25,6 @@ from .modeling_auto import (
25
25
  RBLNAutoModelForSeq2SeqLM,
26
26
  RBLNAutoModelForSequenceClassification,
27
27
  RBLNAutoModelForSpeechSeq2Seq,
28
+ RBLNAutoModelForTextEncoding,
28
29
  RBLNAutoModelForVision2Seq,
29
30
  )
@@ -35,6 +35,8 @@ from transformers.models.auto.modeling_auto import (
35
35
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
36
36
  MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
37
37
  MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
38
+ MODEL_FOR_TEXT_ENCODING_MAPPING,
39
+ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
38
40
  MODEL_FOR_VISION_2_SEQ_MAPPING,
39
41
  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
40
42
  MODEL_MAPPING,
@@ -115,3 +117,8 @@ class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
115
117
  class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
116
118
  _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
117
119
  _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
120
+
121
+
122
+ class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
123
+ _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
124
+ _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
@@ -16,9 +16,7 @@ from typing import Tuple
16
16
 
17
17
  import torch
18
18
  from torch import nn
19
- from transformers.modeling_attn_mask_utils import (
20
- _prepare_4d_attention_mask,
21
- )
19
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
22
20
  from transformers.utils import logging
23
21
 
24
22
  from ..seq2seq.seq2seq_architecture import (
@@ -32,3 +32,5 @@ class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
32
32
  This configuration class stores the configuration parameters specific to
33
33
  RBLN-optimized BART models for conditional text generation tasks.
34
34
  """
35
+
36
+ support_paged_attention = True
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional
15
+ from typing import Any, Optional
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
@@ -62,7 +62,7 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
62
62
  vision_model: Optional[RBLNModelConfig] = None,
63
63
  qformer: Optional[RBLNModelConfig] = None,
64
64
  language_model: Optional[RBLNModelConfig] = None,
65
- **kwargs: Dict[str, Any],
65
+ **kwargs: Any,
66
66
  ):
67
67
  """
68
68
  Args:
@@ -35,11 +35,7 @@ from ....modeling import RBLNModel
35
35
  logger = logging.get_logger(__name__)
36
36
 
37
37
  if TYPE_CHECKING:
38
- from transformers import (
39
- AutoFeatureExtractor,
40
- AutoProcessor,
41
- AutoTokenizer,
42
- )
38
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
43
39
 
44
40
 
45
41
  class LoopProjector: