optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (105)
  1. optimum/rbln/__init__.py +32 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +24 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +31 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +25 -251
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/ops/flash_attn.py

@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_decode_kv_fp8.register_fake
+def paged_flash_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_attn_prefill_kv_fp8.register_fake
+def paged_flash_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    mask: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_causal_attn_decode",
     mutates_args=(["kcache", "vcache"]),
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
     return torch.empty_like(q)
 
 
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_decode_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_decode_kv_fp8.register_fake
+def paged_flash_causal_attn_decode_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
 @torch.library.custom_op(
     "rbln_custom_ops::paged_flash_causal_attn_prefill",
     mutates_args=(["kcache", "vcache"]),
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
     mask: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
+    mutates_args=(["kcache", "vcache"]),
+)
+def paged_flash_causal_attn_prefill_kv_fp8(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
+
+
+@paged_flash_causal_attn_prefill_kv_fp8.register_fake
+def paged_flash_causal_attn_prefill_kv_fp8_fake(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    kcache: Tensor,
+    vcache: Tensor,
+    seq: Tensor,
+    scale: Tensor,
+    block_table: Tensor,
+    block_size: int,
+    partition: int,
+    is_bidirectional: bool,
+    k_scale: Tensor,
+    v_scale: Tensor,
+    mask: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(q)
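
The four *_kv_fp8 ops above mirror the existing decode/prefill pairs, adding per-tensor k_scale/v_scale factors for an fp8-quantized KV cache. Each op is registered twice: once as the custom op proper (whose real kernel is supplied by the RBLN compiler, hence the placeholder body) and once as a "fake" kernel that only propagates shape and dtype so the graph can be traced. A minimal, self-contained sketch of that registration pattern (PyTorch >= 2.4); the op name demo_ops::scaled_identity and its behavior are illustrative, not part of optimum-rbln:

    import torch
    from torch import Tensor


    # Real op: optimum-rbln leaves the body as a placeholder and lets the
    # RBLN compiler supply the kernel; this demo computes something so it runs.
    @torch.library.custom_op("demo_ops::scaled_identity", mutates_args=())
    def scaled_identity(q: Tensor, scale: float) -> Tensor:
        return q * scale


    # Fake (meta) kernel: reports only the output shape/dtype, which is all
    # torch.compile/torch.export need to trace through the op.
    @scaled_identity.register_fake
    def scaled_identity_fake(q: Tensor, scale: float) -> Tensor:
        return torch.empty_like(q)


    print(scaled_identity(torch.ones(2, 4), 0.5))

Declaring mutates_args=(["kcache", "vcache"]) marks the cache tensors as written in place, so the tracer treats the op as mutating them rather than producing fresh outputs.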

optimum/rbln/__init__.py

@@ -34,6 +34,8 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
         "RBLNBartForConditionalGeneration",
         "RBLNBartForConditionalGenerationConfig",
         "RBLNBartModel",
@@ -68,6 +70,8 @@ _import_structure = {
         "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
+        "RBLNDepthAnythingForDepthEstimation",
+        "RBLNDepthAnythingForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaModel",
@@ -82,6 +86,12 @@ _import_structure = {
         "RBLNGPT2LMHeadModelConfig",
         "RBLNGPT2Model",
         "RBLNGPT2ModelConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformer",
@@ -136,6 +146,8 @@ _import_structure = {
         "RBLNRobertaForSequenceClassificationConfig",
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
+        "RBLNSwinBackbone",
+        "RBLNSwinBackboneConfig",
         "RBLNT5EncoderModel",
         "RBLNT5EncoderModelConfig",
         "RBLNT5ForConditionalGeneration",
@@ -171,7 +183,9 @@ if TYPE_CHECKING:
         RBLNAutoModelForSeq2SeqLM,
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
         RBLNBartForConditionalGeneration,
         RBLNBartForConditionalGenerationConfig,
         RBLNBartModel,
@@ -202,6 +216,8 @@ if TYPE_CHECKING:
         RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNDepthAnythingForDepthEstimation,
+        RBLNDepthAnythingForDepthEstimationConfig,
         RBLNDistilBertForQuestionAnswering,
         RBLNDistilBertForQuestionAnsweringConfig,
         RBLNDPTForDepthEstimation,
@@ -220,6 +236,12 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
@@ -272,6 +294,8 @@ if TYPE_CHECKING:
         RBLNRobertaForSequenceClassificationConfig,
         RBLNSiglipVisionModel,
         RBLNSiglipVisionModelConfig,
+        RBLNSwinBackbone,
+        RBLNSwinBackboneConfig,
         RBLNT5EncoderModel,
         RBLNT5EncoderModelConfig,
         RBLNT5ForConditionalGeneration,

optimum/rbln/transformers/configuration_generic.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 from ..configuration_utils import RBLNModelConfig
 
@@ -25,7 +25,8 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
-        **kwargs: Dict[str, Any],
+        model_input_shapes: Optional[List[Tuple[int, int]]] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -45,6 +46,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
         self.model_input_names = model_input_names or self.rbln_model_input_names
+        self.model_input_shapes = model_input_shapes
 
 
 class RBLNImageModelConfig(RBLNModelConfig):
@@ -52,7 +54,7 @@ class RBLNImageModelConfig(RBLNModelConfig):
         self,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
         batch_size: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
@@ -124,7 +126,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         max_length: Optional[int] = None,
         num_mel_bins: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
    ):
         """
         Args:
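
One typing note on the **kwargs change repeated through these configuration hunks: an annotation on **kwargs describes each individual value, not the mapping itself, so **kwargs: Any is the correct form, while the old **kwargs: Dict[str, Any] declared every value to itself be a dict. A tiny sketch:

    from typing import Any


    # Inside the function, kwargs is always a dict[str, Any]; the annotation
    # on **kwargs applies to each individual value.
    def configure(**kwargs: Any) -> None:
        for key, value in kwargs.items():
            print(key, value)


    configure(batch_size=1, max_seq_len=512)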

optimum/rbln/transformers/modeling_generic.py

@@ -34,10 +34,7 @@ from transformers import (
     AutoModelForTextEncoding,
     PretrainedConfig,
 )
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    QuestionAnsweringModelOutput,
-)
+from transformers.modeling_outputs import BaseModelOutput, QuestionAnsweringModelOutput
 
 from ..configuration_utils import RBLNCompileConfig
 from ..modeling import RBLNModel
@@ -130,10 +127,18 @@ class RBLNTransformerEncoder(RBLNModel):
                 "This is an internal error. Please report it to the developers."
             )
 
-        input_info = [
-            (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
-            for model_input_name in rbln_config.model_input_names
-        ]
+        if rbln_config.model_input_shapes is None:
+            input_info = [
+                (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
+                for model_input_name in rbln_config.model_input_names
+            ]
+        else:
+            input_info = [
+                (model_input_name, model_input_shape, cls.rbln_dtype)
+                for model_input_name, model_input_shape in zip(
+                    rbln_config.model_input_names, rbln_config.model_input_shapes
+                )
+            ]
 
         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
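
The new model_input_shapes branch lets an encoder config pin a distinct static shape per input instead of broadcasting one [batch_size, max_seq_len] shape across every entry of model_input_names. A standalone sketch of the added branch; the names, shapes, and the "int64" dtype string are illustrative stand-ins for the rbln_config fields:

    batch_size, max_seq_len = 1, 512
    model_input_names = ["input_ids", "attention_mask"]
    model_input_shapes = [(1, 512), (1, 512)]  # one (batch, seq) shape per input

    if model_input_shapes is None:
        # Old behavior: every input shares [batch_size, max_seq_len].
        input_info = [(name, [batch_size, max_seq_len], "int64") for name in model_input_names]
    else:
        # New behavior: shapes pair positionally with model_input_names.
        input_info = [
            (name, list(shape), "int64")
            for name, shape in zip(model_input_names, model_input_shapes)
        ]

    print(input_info)

Note the pairing is positional, so model_input_shapes must line up one-to-one with model_input_names.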

optimum/rbln/transformers/modeling_outputs.py (new file)

@@ -0,0 +1,37 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from transformers.modeling_outputs import ModelOutput
+
+
+@dataclass
+class RBLNDecoderOnlyOutput(ModelOutput):
+    logits: torch.FloatTensor = None
+    generate_idx: torch.Tensor = None
+    padded_cache_lengths: int = None
+
+
+@dataclass
+class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
+    attention_mask: Optional[torch.Tensor] = None
+
+
+@dataclass
+class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
+    last_hidden_states: torch.FloatTensor = None
+    params: Tuple[torch.FloatTensor] = None
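
These dataclasses get the standard transformers ModelOutput ergonomics: attribute, key, and index access, with unset (None) fields dropped from iteration. A quick sketch with toy values, re-declaring RBLNDecoderOnlyOutput exactly as above so it runs without optimum-rbln installed:

    from dataclasses import dataclass

    import torch
    from transformers.modeling_outputs import ModelOutput


    @dataclass
    class RBLNDecoderOnlyOutput(ModelOutput):  # same definition as the new module
        logits: torch.FloatTensor = None
        generate_idx: torch.Tensor = None
        padded_cache_lengths: int = None


    out = RBLNDecoderOnlyOutput(logits=torch.zeros(1, 4), generate_idx=torch.tensor([3]))
    print(out.logits.shape)     # attribute access: torch.Size([1, 4])
    print(out["generate_idx"])  # dict-style access: tensor([3])
    print(out.to_tuple())       # only the two non-None fields are included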

optimum/rbln/transformers/models/__init__.py

@@ -36,6 +36,8 @@ _import_structure = {
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
         "RBLNAutoModelForImageTextToText",
+        "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -89,6 +91,7 @@ _import_structure = {
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
     ],
+    "depth_anything": ["RBLNDepthAnythingForDepthEstimationConfig", "RBLNDepthAnythingForDepthEstimation"],
     "dpt": [
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
@@ -140,6 +143,10 @@ _import_structure = {
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
     ],
+    "swin": [
+        "RBLNSwinBackbone",
+        "RBLNSwinBackboneConfig",
+    ],
     "time_series_transformer": [
         "RBLNTimeSeriesTransformerForPrediction",
         "RBLNTimeSeriesTransformerForPredictionConfig",
@@ -159,13 +166,18 @@ _import_structure = {
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNXLMRobertaForSequenceClassificationConfig",
     ],
+    "grounding_dino": [
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+    ],
 }
 
 if TYPE_CHECKING:
-    from .audio_spectrogram_transformer import (
-        RBLNASTForAudioClassification,
-        RBLNASTForAudioClassificationConfig,
-    )
+    from .audio_spectrogram_transformer import RBLNASTForAudioClassification, RBLNASTForAudioClassificationConfig
     from .auto import (
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
@@ -179,7 +191,9 @@ if TYPE_CHECKING:
         RBLNAutoModelForSeq2SeqLM,
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
     )
     from .bart import (
         RBLNBartForConditionalGeneration,
@@ -213,24 +227,16 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
     )
-    from .colpali import (
-        RBLNColPaliForRetrieval,
-        RBLNColPaliForRetrievalConfig,
-    )
+    from .colpali import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
     from .decoderonly import (
         RBLNDecoderOnlyModel,
         RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
     )
-    from .distilbert import (
-        RBLNDistilBertForQuestionAnswering,
-        RBLNDistilBertForQuestionAnsweringConfig,
-    )
-    from .dpt import (
-        RBLNDPTForDepthEstimation,
-        RBLNDPTForDepthEstimationConfig,
-    )
+    from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
+    from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
+    from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
     from .gemma3 import (
@@ -240,6 +246,14 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
     )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+    from .grounding_dino import (
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
+    )
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
@@ -276,6 +290,7 @@ if TYPE_CHECKING:
         RBLNRobertaForSequenceClassificationConfig,
     )
     from .siglip import RBLNSiglipVisionModel, RBLNSiglipVisionModelConfig
+    from .swin import RBLNSwinBackbone, RBLNSwinBackboneConfig
     from .t5 import (
         RBLNT5EncoderModel,
         RBLNT5EncoderModelConfig,

optimum/rbln/transformers/models/auto/__init__.py

@@ -25,5 +25,7 @@ from .modeling_auto import (
     RBLNAutoModelForSeq2SeqLM,
     RBLNAutoModelForSequenceClassification,
     RBLNAutoModelForSpeechSeq2Seq,
+    RBLNAutoModelForTextEncoding,
     RBLNAutoModelForVision2Seq,
+    RBLNAutoModelForZeroShotObjectDetection,
 )
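
Both __init__.py files above use the transformers lazy-import pattern: _import_structure maps submodules to exported names so nothing is imported until first attribute access, while the TYPE_CHECKING branch gives type checkers real imports. A condensed sketch of how such an __init__.py is typically wired up; it belongs inside a package __init__.py, and the _LazyModule helper shown is transformers' own (whether optimum-rbln uses exactly this helper is not visible in this diff):

    import sys
    from typing import TYPE_CHECKING

    from transformers.utils import _LazyModule

    _import_structure = {
        "swin": ["RBLNSwinBackbone", "RBLNSwinBackboneConfig"],
    }

    if TYPE_CHECKING:
        # Static-analysis-only imports; never executed at runtime.
        from .swin import RBLNSwinBackbone, RBLNSwinBackboneConfig
    else:
        # Replace this module with a proxy that imports submodules lazily.
        sys.modules[__name__] = _LazyModule(
            __name__, globals()["__file__"], _import_structure, module_spec=__spec__
        )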

optimum/rbln/transformers/models/auto/modeling_auto.py

@@ -35,8 +35,12 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_TEXT_ENCODING_MAPPING,
+    MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
     MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
@@ -115,3 +119,13 @@ class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
 class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
     _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+
+
+class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
+    _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
+
+
+class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
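
Each new auto class plugs the stock transformers task mappings into optimum-rbln's _BaseAutoModelClass, so model-type dispatch stays in sync with upstream. A hedged usage sketch following the pattern of the existing RBLN auto classes; the checkpoint name is an illustrative example, and export=True is the usual optimum-style compile-on-load flag rather than something shown in this diff:

    from optimum.rbln import RBLNAutoModelForZeroShotObjectDetection

    # Looks up the checkpoint's model_type in transformers'
    # MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING and dispatches to the
    # matching RBLN class (RBLNGroundingDinoForObjectDetection for Grounding DINO).
    model = RBLNAutoModelForZeroShotObjectDetection.from_pretrained(
        "IDEA-Research/grounding-dino-tiny",  # example checkpoint, not from the diff
        export=True,  # compile the PyTorch weights for the RBLN NPU
    )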

optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -16,9 +16,7 @@ from typing import Tuple
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-)
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.utils import logging
 
 from ..seq2seq.seq2seq_architecture import (

optimum/rbln/transformers/models/bart/configuration_bart.py

@@ -32,3 +32,5 @@ class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
     This configuration class stores the configuration parameters specific to
     RBLN-optimized BART models for conditional text generation tasks.
     """
+
+    support_paged_attention = True

optimum/rbln/transformers/models/bert/bert_architecture.py (new file)

@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
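
The tuple filtering exists because a compiled graph cannot return None entries, and BERT's tuple outputs may contain None slots (e.g. the pooled output when pooling is disabled). A toy sketch of the same pattern, independent of the classes above:

    import torch


    class ToyEncoder(torch.nn.Module):
        def forward(self, x):
            # (last_hidden_state, pooled_output) with pooling disabled.
            return (x * 2, None)


    class FilteringWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, *args, **kwargs):
            output = self.model(*args, **kwargs)
            if isinstance(output, tuple):
                # Drop None entries so the compiler sees only tensors.
                return tuple(x for x in output if x is not None)
            return output


    print(FilteringWrapper(ToyEncoder())(torch.ones(2)))  # (tensor([2., 2.]),)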

optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ....utils.logging import get_logger
+import torch
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +34,10 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """

optimum/rbln/transformers/models/blip_2/configuration_blip_2.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -62,7 +62,7 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:

optimum/rbln/transformers/models/blip_2/modeling_blip_2.py

@@ -35,11 +35,7 @@ from ....modeling import RBLNModel
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
 class LoopProjector:
@@ -178,7 +174,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()
 
     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens
 

optimum/rbln/transformers/models/clip/configuration_clip.py

@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNCLIPTextModelConfig(RBLNModelConfig):
-    def __init__(self, batch_size: Optional[int] = None, **kwargs: Dict[str, Any]):
+    def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
         """
         Args:
             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
@@ -50,7 +50,7 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
         interpolate_pos_encoding: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:

optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -4,10 +4,7 @@ import torch
 from torch import nn
 from transformers import GemmaForCausalLM, GemmaModel
 
-from ..decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-)
+from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
 
 
 def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):