optimum_rbln-0.8.2a7-py3-none-any.whl → optimum_rbln-0.8.3-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.

Files changed (105)
  1. optimum/rbln/__init__.py +36 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +28 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +35 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +64 -258
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/flash_attn.py

@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode_kv_fp8.register_fake
+def paged_flash_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill_kv_fp8.register_fake
+def paged_flash_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_causal_attn_decode",
     mutates_args=(["kcache", "vcache"]),
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode_kv_fp8.register_fake
+def paged_flash_causal_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_causal_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
     mask: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill_kv_fp8.register_fake
+def paged_flash_causal_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
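Each new `*_kv_fp8` op repeats the two-part `torch.library` registration used by the existing ops: a Python body that only stands in for the kernel (the real computation is presumably lowered by the RBLN compiler backend, given the placeholder bodies), plus a `register_fake` implementation that tells the tracer the output's shape and dtype. A minimal, self-contained sketch of the same pattern; the `demo_ops::scaled_copy` op is illustrative, not part of the package, and requires PyTorch 2.4+:

```python
import torch
from torch import Tensor


# Illustrative op (not part of optimum-rbln) showing the registration pattern.
@torch.library.custom_op("demo_ops::scaled_copy", mutates_args=["out"])
def scaled_copy(x: Tensor, scale: float, out: Tensor) -> Tensor:
    out.copy_(x * scale)  # declared mutation, mirroring kcache/vcache above
    return out.clone()    # custom ops must not return aliases of their inputs


@scaled_copy.register_fake
def _(x: Tensor, scale: float, out: Tensor) -> Tensor:
    # Fake (meta) kernel: report output metadata only, no computation. This is
    # what lets torch.compile / torch.export trace graphs containing the op.
    return torch.empty_like(x)


out = torch.empty(4)
print(torch.ops.demo_ops.scaled_copy(torch.randn(4), 2.0, out))
```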
optimum/rbln/transformers/__init__.py

@@ -34,6 +34,8 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
         "RBLNBartForConditionalGeneration",
         "RBLNBartForConditionalGenerationConfig",
         "RBLNBartModel",
@@ -62,10 +64,14 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjectionConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNDecoderOnlyModelConfig",
+        "RBLNDecoderOnlyModel",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
+        "RBLNDepthAnythingForDepthEstimation",
+        "RBLNDepthAnythingForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaModel",
@@ -80,6 +86,12 @@ _import_structure = {
         "RBLNGPT2LMHeadModelConfig",
         "RBLNGPT2Model",
         "RBLNGPT2ModelConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformer",
@@ -134,6 +146,8 @@ _import_structure = {
         "RBLNRobertaForSequenceClassificationConfig",
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
+        "RBLNSwinBackbone",
+        "RBLNSwinBackboneConfig",
         "RBLNT5EncoderModel",
         "RBLNT5EncoderModelConfig",
         "RBLNT5ForConditionalGeneration",
@@ -169,7 +183,9 @@ if TYPE_CHECKING:
         RBLNAutoModelForSeq2SeqLM,
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
         RBLNBartForConditionalGeneration,
         RBLNBartForConditionalGenerationConfig,
         RBLNBartModel,
@@ -196,8 +212,12 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjectionConfig,
         RBLNColPaliForRetrieval,
         RBLNColPaliForRetrievalConfig,
+        RBLNDecoderOnlyModel,
+        RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNDepthAnythingForDepthEstimation,
+        RBLNDepthAnythingForDepthEstimationConfig,
         RBLNDistilBertForQuestionAnswering,
         RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
@@ -216,6 +236,12 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
@@ -268,6 +294,8 @@ if TYPE_CHECKING:
         RBLNRobertaForSequenceClassificationConfig,
         RBLNSiglipVisionModel,
         RBLNSiglipVisionModelConfig,
+        RBLNSwinBackbone,
+        RBLNSwinBackboneConfig,
         RBLNT5EncoderModel,
         RBLNT5EncoderModelConfig,
         RBLNT5ForConditionalGeneration,
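Both `__init__.py` files follow the transformers-style lazy-import layout: every public name is listed once in `_import_structure` for runtime resolution and once under `TYPE_CHECKING` so static analyzers see the real symbols. A minimal sketch of the mechanism, assuming the package delegates to `transformers.utils._LazyModule` as most optimum packages do (symbol set trimmed to one entry for brevity):

```python
# Sketch of the lazy-import pattern used by these __init__.py files.
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {"models": ["RBLNSwinBackbone"]}

if TYPE_CHECKING:
    from .models import RBLNSwinBackbone  # real import, seen only by type checkers
else:
    # At runtime the module is swapped for a proxy that imports submodules on
    # first attribute access, keeping `import optimum.rbln` cheap.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```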
optimum/rbln/transformers/configuration_generic.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 from ..configuration_utils import RBLNModelConfig
 
@@ -25,7 +25,8 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
-        **kwargs: Dict[str, Any],
+        model_input_shapes: Optional[List[Tuple[int, int]]] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -45,6 +46,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
         self.model_input_names = model_input_names or self.rbln_model_input_names
+        self.model_input_shapes = model_input_shapes
 
 
 class RBLNImageModelConfig(RBLNModelConfig):
@@ -52,7 +54,7 @@ class RBLNImageModelConfig(RBLNModelConfig):
         self,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
         batch_size: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
@@ -124,7 +126,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         max_length: Optional[int] = None,
         num_mel_bins: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
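The new `model_input_shapes` parameter lets a caller pin an explicit shape per input instead of the uniform `[batch_size, max_seq_len]` that `RBLNTransformerEncoderConfig` otherwise derives. A hypothetical sketch; `RBLNBertModelConfig` is assumed to subclass this config and to be re-exported at the top level, and the values are illustrative:

```python
# Hypothetical usage: one (batch, seq_len) shape per entry of model_input_names.
from optimum.rbln import RBLNBertModelConfig  # assumed top-level re-export

rbln_config = RBLNBertModelConfig(
    model_input_names=["input_ids", "attention_mask"],
    model_input_shapes=[(1, 384), (1, 384)],
)
```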
optimum/rbln/transformers/modeling_generic.py

@@ -34,10 +34,7 @@ from transformers import (
     AutoModelForTextEncoding,
     PretrainedConfig,
 )
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    QuestionAnsweringModelOutput,
-)
+from transformers.modeling_outputs import BaseModelOutput, QuestionAnsweringModelOutput
 
 from ..configuration_utils import RBLNCompileConfig
 from ..modeling import RBLNModel
@@ -130,10 +127,18 @@ class RBLNTransformerEncoder(RBLNModel):
             "This is an internal error. Please report it to the developers."
         )
 
-        input_info = [
-            (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
-            for model_input_name in rbln_config.model_input_names
-        ]
+        if rbln_config.model_input_shapes is None:
+            input_info = [
+                (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
+                for model_input_name in rbln_config.model_input_names
+            ]
+        else:
+            input_info = [
+                (model_input_name, model_input_shape, cls.rbln_dtype)
+                for model_input_name, model_input_shape in zip(
+                    rbln_config.model_input_names, rbln_config.model_input_shapes
+                )
+            ]
 
         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config
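Note that the new branch pairs names with shapes via `zip`, which silently truncates if the two lists differ in length, so callers are presumably expected to supply exactly one shape per input name. A standalone illustration of the two `input_info` branches above, with `"int64"` standing in for `cls.rbln_dtype`:

```python
names = ["input_ids", "attention_mask"]
shapes = [(1, 384), (1, 384)]

default_info = [(name, [1, 512], "int64") for name in names]              # uniform branch
explicit_info = [(name, shape, "int64") for name, shape in zip(names, shapes)]  # model_input_shapes branch
print(explicit_info)
# [('input_ids', (1, 384), 'int64'), ('attention_mask', (1, 384), 'int64')]
```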
optimum/rbln/transformers/modeling_outputs.py (new file)

@@ -0,0 +1,37 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from transformers.modeling_outputs import ModelOutput
+
+
+@dataclass
+class RBLNDecoderOnlyOutput(ModelOutput):
+    logits: torch.FloatTensor = None
+    generate_idx: torch.Tensor = None
+    padded_cache_lengths: int = None
+
+
+@dataclass
+class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
+    attention_mask: Optional[torch.Tensor] = None
+
+
+@dataclass
+class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
+    last_hidden_states: torch.FloatTensor = None
+    params: Tuple[torch.FloatTensor] = None
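These dataclasses inherit `ModelOutput` behavior from transformers: fields are reachable by attribute, by key, or by integer index, and unset (`None`) fields are dropped from `to_tuple()`. A quick sketch using the import path from the diff:

```python
import torch

from optimum.rbln.transformers.modeling_outputs import RBLNDecoderOnlyOutput

out = RBLNDecoderOnlyOutput(logits=torch.zeros(1, 4), generate_idx=torch.tensor([3]))
out.logits            # attribute access
out["generate_idx"]   # dict-style access
out.to_tuple()        # (logits, generate_idx) -- padded_cache_lengths is None, so dropped
out[0]                # integer indexing routes through to_tuple()
```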
optimum/rbln/transformers/models/__init__.py

@@ -36,6 +36,8 @@ _import_structure = {
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
         "RBLNAutoModelForImageTextToText",
+        "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -84,9 +86,12 @@ _import_structure = {
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
     ],
     "decoderonly": [
+        "RBLNDecoderOnlyModelConfig",
+        "RBLNDecoderOnlyModel",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
     ],
+    "depth_anything": ["RBLNDepthAnythingForDepthEstimationConfig", "RBLNDepthAnythingForDepthEstimation"],
     "dpt": [
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
@@ -138,6 +143,10 @@ _import_structure = {
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
     ],
+    "swin": [
+        "RBLNSwinBackbone",
+        "RBLNSwinBackboneConfig",
+    ],
     "time_series_transformer": [
         "RBLNTimeSeriesTransformerForPrediction",
         "RBLNTimeSeriesTransformerForPredictionConfig",
@@ -157,13 +166,18 @@ _import_structure = {
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNXLMRobertaForSequenceClassificationConfig",
     ],
+    "grounding_dino": [
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+    ],
 }
 
 if TYPE_CHECKING:
-    from .audio_spectrogram_transformer import (
-        RBLNASTForAudioClassification,
-        RBLNASTForAudioClassificationConfig,
-    )
+    from .audio_spectrogram_transformer import RBLNASTForAudioClassification, RBLNASTForAudioClassificationConfig
     from .auto import (
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
@@ -177,7 +191,9 @@ if TYPE_CHECKING:
         RBLNAutoModelForSeq2SeqLM,
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
     )
     from .bart import (
         RBLNBartForConditionalGeneration,
@@ -211,22 +227,16 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
     )
-    from .colpali import (
-        RBLNColPaliForRetrieval,
-        RBLNColPaliForRetrievalConfig,
-    )
+    from .colpali import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
     from .decoderonly import (
+        RBLNDecoderOnlyModel,
+        RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
     )
-    from .distilbert import (
-        RBLNDistilBertForQuestionAnswering,
-        RBLNDistilBertForQuestionAnsweringConfig,
-    )
-    from .dpt import (
-        RBLNDPTForDepthEstimation,
-        RBLNDPTForDepthEstimationConfig,
-    )
+    from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
+    from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
+    from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
     from .gemma3 import (
@@ -236,6 +246,14 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
     )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+    from .grounding_dino import (
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
+    )
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
@@ -272,6 +290,7 @@ if TYPE_CHECKING:
         RBLNRobertaForSequenceClassificationConfig,
     )
     from .siglip import RBLNSiglipVisionModel, RBLNSiglipVisionModelConfig
+    from .swin import RBLNSwinBackbone, RBLNSwinBackboneConfig
     from .t5 import (
         RBLNT5EncoderModel,
         RBLNT5EncoderModelConfig,
optimum/rbln/transformers/models/auto/__init__.py

@@ -25,5 +25,7 @@ from .modeling_auto import (
     RBLNAutoModelForSeq2SeqLM,
     RBLNAutoModelForSequenceClassification,
     RBLNAutoModelForSpeechSeq2Seq,
+    RBLNAutoModelForTextEncoding,
     RBLNAutoModelForVision2Seq,
+    RBLNAutoModelForZeroShotObjectDetection,
 )
optimum/rbln/transformers/models/auto/modeling_auto.py

@@ -35,8 +35,12 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_TEXT_ENCODING_MAPPING,
+    MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
     MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
@@ -115,3 +119,13 @@ class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
 class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
     _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+
+
+class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
+    _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
+
+
+class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
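With the HF mappings wired in, the two new auto classes should dispatch like the existing RBLN auto classes. A hypothetical call, assuming the class is re-exported from `optimum.rbln` and that `export=True` triggers compilation as it does elsewhere in the library; the checkpoint ID is illustrative:

```python
from optimum.rbln import RBLNAutoModelForZeroShotObjectDetection

model = RBLNAutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny",  # should resolve to RBLNGroundingDinoForObjectDetection
    export=True,  # compile the PyTorch checkpoint for the RBLN NPU
)
```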
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -16,9 +16,7 @@ from typing import Tuple
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-)
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.utils import logging
 
 from ..seq2seq.seq2seq_architecture import (
optimum/rbln/transformers/models/bart/configuration_bart.py

@@ -32,3 +32,5 @@ class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized BART models for conditional text generation tasks.
     """
+
+    support_paged_attention = True
optimum/rbln/transformers/models/bert/bert_architecture.py (new file)

@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ....utils.logging import get_logger
+import torch
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +34,10 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -62,7 +62,7 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py

@@ -35,11 +35,7 @@ from ....modeling import RBLNModel
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
 class LoopProjector:
@@ -178,7 +174,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()
 
     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens
 
optimum/rbln/transformers/models/clip/configuration_clip.py

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNCLIPTextModelConfig(RBLNModelConfig):
-    def __init__(self, batch_size: Optional[int] = None, **kwargs: Dict[str, Any]):
+    def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
         """
         Args:
             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
@@ -50,7 +50,7 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
         interpolate_pos_encoding: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
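One recurring fix across these configuration files deserves a note: the annotation on `**kwargs` describes each individual keyword value, not the collected dict, so the old `**kwargs: Dict[str, Any]` wrongly told type checkers that every keyword argument is itself a dict. The corrected form:

```python
from typing import Any, Dict

def old_style(**kwargs: Dict[str, Any]): ...  # checkers: each value must be a Dict[str, Any]
def new_style(**kwargs: Any): ...             # correct: kwargs itself is already Dict[str, Any]
```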