optimum-rbln 0.9.5a4__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +196 -52
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +5 -4
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_attention_utils.py +15 -9
  17. optimum/rbln/transformers/models/__init__.py +10 -0
  18. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  19. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  20. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  21. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +2 -2
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +26 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +2 -1
  24. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +45 -21
  25. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  26. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  27. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  28. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  29. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +4 -176
  30. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +4 -3
  31. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +10 -7
  32. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  33. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  34. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  35. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  36. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +7 -7
  37. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  38. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +2 -0
  39. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +2 -0
  40. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  41. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  42. optimum/rbln/utils/deprecation.py +78 -1
  43. optimum/rbln/utils/hub.py +93 -2
  44. optimum/rbln/utils/runtime_utils.py +2 -2
  45. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +1 -1
  46. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +49 -42
  47. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  48. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  49. {optimum_rbln-0.9.5a4.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
@@ -192,20 +192,24 @@ class RBLNDecoderOnlyFlashAttentionMixin:
             available_dram - without_dramtensor for without_dramtensor in alloc_per_node_without_dram
         ]

-        kvcache_tensor_sizes: dict[str, list[int]] = compiled_models["prefill"].exp_get_dram_tensor_sizes()
+        # kvcache_tensor_sizes[key][node_id][chiplet_id] = alloc_size
+        kvcache_tensor_sizes: dict[str, list[list[int]]] = compiled_models["prefill"].exp_get_dram_tensor_sizes()
         kvcache_meta_can_resize: dict[str, bool] = {
             kvcache_meta.name: kvcache_meta.can_resize for kvcache_meta in rbln_config.kvcache_metas
         }

         def get_updated_kvcache_tensor_sizes(
-            kvcache_tensor_sizes: dict[str, list[int]], multiplier: int
-        ) -> dict[str, list[int]]:
+            kvcache_tensor_sizes: dict[str, list[list[int]]], multiplier: int
+        ) -> dict[str, list[list[int]]]:
             # Get the updated KV cache tensor sizes by multiplying the multiplier
             # with considering attention type (full or sliding), and memory alignment.
-            ret = {}
-            for key, sizes in kvcache_tensor_sizes.items():
+            ret: dict[str, list[list[int]]] = {}
+            for key, sizes_at_node in kvcache_tensor_sizes.items():
                 m = multiplier if kvcache_meta_can_resize[key] else 1
-                ret[key] = [align_2MB(size * m) for size in sizes]
+                ret[key] = [
+                    [align_2MB(size_at_chiplet * m) for size_at_chiplet in sizes_at_node_at_chiplet]
+                    for sizes_at_node_at_chiplet in sizes_at_node
+                ]
             return ret

         def check_memory_fits(multiplier: int) -> tuple[bool, list[int]]:
@@ -214,9 +218,11 @@ class RBLNDecoderOnlyFlashAttentionMixin:
             updated_kvcache_tensor_sizes = get_updated_kvcache_tensor_sizes(kvcache_tensor_sizes, multiplier)

             kvcache_tensor_sizes_at_node: list[int] = [0] * num_node
-            for tensor_sizes in updated_kvcache_tensor_sizes.values():
-                for node_id, size in enumerate(tensor_sizes):
-                    kvcache_tensor_sizes_at_node[node_id] += size
+            for tensor_sizes_at_node in updated_kvcache_tensor_sizes.values():
+                tensor_sizes_at_node: list[list[int]]
+                for node_id, sizes_at_chiplet in enumerate(tensor_sizes_at_node):
+                    sizes_at_chiplet: list[int]
+                    kvcache_tensor_sizes_at_node[node_id] += sum(sizes_at_chiplet)

             fits = all(
                 remaining_dram_at_node[node_id] >= kvcache_tensor_sizes_at_node[node_id] for node_id in range(num_node)
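
For intuition: exp_get_dram_tensor_sizes() now reports allocations per chiplet within each node, so a per-node total is the sum over the chiplet axis. A self-contained toy sketch (the key name and sizes below are hypothetical, not taken from the package):

# sizes[key][node_id][chiplet_id] = alloc_size
kvcache_tensor_sizes = {"past_key_values_0": [[2**21, 2**21], [2**21, 2**21]]}  # 2 nodes x 2 chiplets

num_node = 2
kvcache_tensor_sizes_at_node = [0] * num_node
for sizes_at_node in kvcache_tensor_sizes.values():
    for node_id, sizes_at_chiplet in enumerate(sizes_at_node):
        # Sum over the new chiplet axis to recover the per-node footprint.
        kvcache_tensor_sizes_at_node[node_id] += sum(sizes_at_chiplet)

print(kvcache_tensor_sizes_at_node)  # [4194304, 4194304]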
@@ -79,6 +79,10 @@ _import_structure = {
         "RBLNColQwen2ForRetrieval",
         "RBLNColQwen2ForRetrievalConfig",
     ],
+    "detr": [
+        "RBLNDetrForObjectDetection",
+        "RBLNDetrForObjectDetectionConfig",
+    ],
     "distilbert": [
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
@@ -169,6 +173,10 @@ _import_structure = {
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
     ],
+    "mixtral": [
+        "RBLNMixtralForCausalLM",
+        "RBLNMixtralForCausalLMConfig",
+    ],
     "swin": [
         "RBLNSwinBackbone",
         "RBLNSwinBackboneConfig",
@@ -264,6 +272,7 @@ if TYPE_CHECKING:
         RBLNLoRAConfig,
     )
     from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
+    from .detr import RBLNDetrForObjectDetection, RBLNDetrForObjectDetectionConfig
     from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
     from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
@@ -296,6 +305,7 @@ if TYPE_CHECKING:
     from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
     from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
     from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
+    from .mixtral import RBLNMixtralForCausalLM, RBLNMixtralForCausalLMConfig
     from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
     from .paligemma import (
         RBLNPaliGemmaForConditionalGeneration,
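
With these registrations, the new Mixtral classes become importable from the top-level package. A minimal sketch of the usual optimum-rbln export flow (the checkpoint id and save path are illustrative, not part of this diff):

from optimum.rbln import RBLNMixtralForCausalLM

# Compile a Hugging Face checkpoint for RBLN NPUs, then save the artifact.
model = RBLNMixtralForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative checkpoint id
    export=True,
)
model.save_pretrained("mixtral-rbln")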
@@ -184,8 +184,8 @@ class _BaseAutoModelClass:
         model_id: Union[str, Path],
         export: bool = None,
         rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
-        **kwargs,
-    ):
+        **kwargs: Optional[Dict[str, Any]],
+    ) -> RBLNBaseModel:
         """
         Load an RBLN-accelerated model from a pretrained checkpoint or a compiled RBLN artifact.

@@ -213,7 +213,7 @@ class _BaseAutoModelClass:
             `token`, `trust_remote_code`, `cache_dir`, `subfolder`, `local_files_only`).

         Returns:
-            An instantiated RBLN model ready for inference on RBLN NPUs.
+            RBLNBaseModel: An instantiated RBLN model ready for inference on RBLN NPUs.
         """
         rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
         return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
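
The annotation change documents existing behavior: the auto class resolves a concrete RBLN class and delegates to its from_pretrained. A sketch, assuming the RBLNAutoModelForCausalLM entry point and the artifact saved in the earlier example:

from optimum.rbln import RBLNAutoModelForCausalLM

# Load a precompiled artifact; the resolved object is an RBLNBaseModel subclass.
model = RBLNAutoModelForCausalLM.from_pretrained("mixtral-rbln", export=False)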
@@ -32,8 +32,13 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     def __init__(
         self,
         batch_size: Optional[int] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        """
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
@@ -53,7 +58,7 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -468,7 +468,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
             input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
             attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices
             inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
-            interpolate_pos_encoding (bool, optional, defaults to False) Whether to interpolate the positional encoding of the image embeddings.
+            interpolate_pos_encoding (bool, optional, defaults to False): Whether to interpolate the positional encoding of the image embeddings.
         Returns:
             A list of strings of length batch_size * num_captions.
         """
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Any, Optional

 from optimum.rbln.configuration_utils import RBLNModelConfig

@@ -61,7 +61,7 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
         batch_size: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
         vlm: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -61,7 +61,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         logits_to_keep: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
         kvcache_metas: Optional[List["KVCacheMeta"]] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -288,6 +288,31 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
     def can_generate(self) -> bool:
         return "decode" in self.phases

+    @property
+    def use_image_prefill(self):
+        return "image_prefill" in self.phases
+
+    @property
+    def image_prefill_runtime_idx(self):
+        return self.phases.index("image_prefill")
+
+    @property
+    def expected_compiled_model_names(self):
+        # ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_2", ...]
+        if self.can_generate:
+            return self.phases[: self.decoder_runtime_idx] + [
+                f"decoder_batch_{batch_size}" for batch_size in self.decoder_batch_sizes
+            ]
+        else:
+            return self.phases
+
+    @property
+    def decoder_runtime_idx(self):
+        if self.can_generate:
+            return self.phases.index("decode")
+        else:
+            raise ValueError("`decode` phase is not in the phases.")
+
     @property
     def nbits_per_param(self) -> int:
         if self.quantization:
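
The naming convention these properties encode can be checked in isolation. A toy reimplementation with hypothetical phases and batch sizes (values below are not from the package):

phases = ["prefill", "image_prefill", "decode"]  # hypothetical
decoder_batch_sizes = [1, 4]                     # hypothetical

# The "decode" phase is replaced by one decoder entry per batch size.
decoder_runtime_idx = phases.index("decode")     # -> 2
expected_compiled_model_names = phases[:decoder_runtime_idx] + [
    f"decoder_batch_{b}" for b in decoder_batch_sizes
]
print(expected_compiled_model_names)
# ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']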
@@ -539,6 +539,7 @@ class DecoderOnlyLayer(nn.Module):
     _POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
     _PRE_FF_LAYERNORM_ATTRS = None
     _POST_FF_LAYERNORM_ATTRS = None
+    _MLP_ATTR = ("mlp",)

     def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
@@ -547,7 +548,7 @@ class DecoderOnlyLayer(nn.Module):
         self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
         self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
         self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
-        self.mlp = layer.mlp
+        self.mlp = _get_attr_from_candidates(layer, self._MLP_ATTR)
         self.self_attn = self_attn
         self._phase = "prefill"
         self.lora_config = lora_config
@@ -104,6 +104,11 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "rbln_config": self.rbln_config,
             "config": self.config,
         }
+
+        if self.rbln_config.use_image_prefill:
+            # TODO(sdk-gen): Implement and combine prefill and image prefill into a single runtime.
+            raise NotImplementedError(f"Image prefill at {self.__class__.__name__} is not supported yet.")
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
@@ -287,9 +292,27 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             phase="prefill",
         )

+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_config.compile_cfgs[rbln_config.image_prefill_runtime_idx]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            compiled_image_prefill = cls._compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config,
+                rbln_config.quantization,
+                phase="image_prefill",
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill
+
         if rbln_config.can_generate:
             wrapped_model.phase = "decode"
-            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+            for batch_size, dec_compile_config in zip(
+                rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[rbln_config.decoder_runtime_idx :]
+            ):
                 dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
                 compiled_decoder = cls._compile_model(
                     wrapped_model,
@@ -548,6 +571,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         compile_cfgs = [prefill_compile_config]

+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+            image_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=image_prefill_input_info
+            )
+            compile_cfgs.append(image_prefill_compile_config)
+
         if rbln_config.can_generate:
             for batch_size in rbln_config.decoder_batch_sizes:
                 dec_input_info = cls.get_input_info(
@@ -569,36 +608,21 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = ["prefill"]
-        if rbln_config.can_generate:
-            expected_model_names.extend(
-                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
-            )
+        expected_model_names = rbln_config.expected_compiled_model_names
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

         ret_val = [
             rebel.Runtime(
-                compiled_models[0],
+                compiled_models[i],
                 tensor_type="pt",
-                device=rbln_config.device_map["prefill"],
+                device=rbln_config.device_map[model_name],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
             )
+            for i, model_name in enumerate(expected_model_names)
         ]
-        if rbln_config.can_generate:
-            ret_val.extend(
-                [
-                    rebel.Runtime(
-                        compiled_models[i + 1],
-                        tensor_type="pt",
-                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                        activate_profiler=rbln_config.activate_profiler,
-                        timeout=rbln_config.timeout,
-                    )
-                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-                ]
-            )
         return ret_val

     def forward(
@@ -0,0 +1,23 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .configuration_detr import RBLNDetrForObjectDetectionConfig
+from .modeling_detr import RBLNDetrForObjectDetection
+
+
+__all__ = [
+    "RBLNDetrForObjectDetectionConfig",
+    "RBLNDetrForObjectDetection",
+]
@@ -0,0 +1,38 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+class RBLNDetrForObjectDetectionConfig(RBLNModelForImageClassificationConfig):
+    """
+    Configuration class for RBLNDetrForObjectDetection.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized DETR models for object detection tasks.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
+                Can be an integer for square images or a tuple (height, width).
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
@@ -0,0 +1,53 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING, Tuple, Union
+
+import torch
+from transformers.models.detr.modeling_detr import DetrObjectDetectionOutput
+
+from ...modeling_generic import RBLNModelForImageClassification
+
+
+if TYPE_CHECKING:
+    pass
+
+
+class RBLNDetrForObjectDetection(RBLNModelForImageClassification):
+    """
+    RBLN optimized DETR model for object detection tasks.
+
+    This class provides hardware-accelerated inference for DETR models
+    on RBLN devices, supporting object detection with detection heads
+    designed for object detection tasks.
+    """
+
+    def forward(
+        self, pixel_values: torch.Tensor, return_dict: bool = None, **kwargs
+    ) -> Union[Tuple, DetrObjectDetectionOutput]:
+        """
+        Forward pass for the RBLN-optimized DETR model for object detection.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
+            return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DetrObjectDetectionOutput object.
+        """
+        output = self.model[0](pixel_values=pixel_values, **kwargs)
+        return DetrObjectDetectionOutput(
+            logits=output[0], pred_boxes=output[1], last_hidden_state=output[2], encoder_last_hidden_state=output[3]
+        )
@@ -58,13 +58,8 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         )
         self.image_prefill_chunk_size = image_prefill_chunk_size

-    @property
-    def use_image_prefill(self):
-        return self.image_prefill_chunk_size is not None
-
-    @property
-    def decoder_runtime_idx(self):
-        return 2 if self.use_image_prefill else 1
+        if not (self.use_attention_mask and self.use_position_ids):
+            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")


 class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
@@ -13,11 +13,9 @@
 # limitations under the License.
 import importlib
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

-import rebel
 import torch
-from rebel.compile_context import CompileContext
 from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
@@ -29,10 +27,7 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
 from ...utils.rbln_runtime_wrapper import LoopProcessor
 from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
 from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
-from ..decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyModelForCausalLM,
-)
-from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
 from .gemma3_runtime_utils import RBLNGemma3RuntimeModel
@@ -455,174 +450,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
             )

-        return rbln_config
-
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
-    ) -> RBLNGemma3ForCausalLMConfig:
-        # Update rbln_config with super class
-        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
-
-        if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
-            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
-
-        if rbln_config.use_image_prefill:
-            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
-                raise NotImplementedError(
-                    "Not implemented for different prefill chunk sizes between text and image prefill."
-                )
-
-            # Update image prefill compile config
-            img_prefill_input_info = cls.get_input_info(
-                batch_size=1,
-                query_length=rbln_config.image_prefill_chunk_size,
-                rbln_config=rbln_config,
-                model_config=model_config,
-            )
-            image_prefill_compile_config = RBLNCompileConfig(
-                compiled_model_name="image_prefill", input_info=img_prefill_input_info
-            )
-            # Insert image_prefill compile config at index 1
-            compile_cfgs = rbln_config.compile_cfgs
-            compile_cfgs.insert(1, image_prefill_compile_config)
-            rbln_config.set_compile_cfgs(compile_cfgs)
+        if "image_prefill" not in rbln_config.phases:
+            rbln_config.phases = ["prefill", "image_prefill", "decode"]

         return rbln_config
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
-
-        rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_compile_config = rbln_compile_configs[0]
-
-        context = CompileContext(use_weight_sharing=True)
-
-        # Here we use meta tensor, for the memory efficiency.
-        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
-        # Mark static tensors (self kv states)
-        static_tensors = {}
-        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
-            if "past_key_values" in name:
-                static_tensors[name] = tensor
-                context.mark_static_address(tensor)
-
-        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
-            try:
-                if quantization:
-                    quantization.maybe_set_quantization_env()
-                original_linear = torch.nn.functional.linear
-                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                compiled_model = cls.compile(
-                    wrapped_model,
-                    compile_config,
-                    create_runtimes=rbln_config.create_runtimes,
-                    device=rbln_config.device,
-                    example_inputs=example_inputs,
-                    compile_context=compile_context,
-                )
-                return compiled_model
-            finally:
-                torch.nn.functional.linear = original_linear
-                if quantization:
-                    quantization.maybe_reset_quantization_env()
-
-        wrapped_model.phase = "prefill"
-        compiled_prefill = compile_model(
-            wrapped_model,
-            prefill_compile_config,
-            prefill_example_inputs,
-            context,
-            rbln_config.quantization,
-        )
-        compiled_models = {"prefill": compiled_prefill}
-
-        if rbln_config.use_image_prefill:
-            image_prefill_compile_config = rbln_compile_configs[1]
-            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
-                fill=0, static_tensors=static_tensors
-            )
-            wrapped_model.phase = "image_prefill"
-            compiled_image_prefill = compile_model(
-                wrapped_model,
-                image_prefill_compile_config,
-                image_prefill_example_inputs,
-                context,
-                rbln_config.quantization,
-            )
-            compiled_models["image_prefill"] = compiled_image_prefill
-
-        wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(
-            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
-        ):
-            dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-            compiled_decoder = compile_model(
-                wrapped_model,
-                dec_compile_config,
-                dec_example_inputs,
-                context,
-                rbln_config.quantization,
-            )
-            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-        return compiled_models
-
-    @classmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNGemma3ForCausalLMConfig,
-    ) -> List[rebel.Runtime]:
-        expected_model_names = [
-            "prefill",
-            *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
-        ]
-        if rbln_config.use_image_prefill:
-            expected_model_names.insert(1, "image_prefill")
-
-        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
-            cls._raise_missing_compiled_file_error(expected_model_names)
-
-        ret_val = [
-            rebel.Runtime(
-                compiled_models[0],
-                tensor_type="pt",
-                device=rbln_config.device_map["prefill"],
-                activate_profiler=rbln_config.activate_profiler,
-                timeout=rbln_config.timeout,
-            )
-        ]
-        if rbln_config.use_image_prefill:
-            ret_val.append(
-                rebel.Runtime(
-                    compiled_models[1],
-                    tensor_type="pt",
-                    device=rbln_config.device_map["image_prefill"],
-                    activate_profiler=rbln_config.activate_profiler,
-                    timeout=rbln_config.timeout,
-                ),
-            )
-
-        ret_val.extend(
-            [
-                rebel.Runtime(
-                    compiled_models[i + rbln_config.decoder_runtime_idx],
-                    tensor_type="pt",
-                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                    activate_profiler=rbln_config.activate_profiler,
-                    timeout=rbln_config.timeout,
-                )
-                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ]
-        )
-
-        return ret_val
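
Net effect of the Gemma3 removals: image prefill is no longer a Gemma3-specific compile path but an ordinary phase handled by the shared RBLNDecoderOnlyModel machinery. A sketch of the resulting usage, with keyword values and checkpoint id that are illustrative only:

from optimum.rbln import RBLNGemma3ForCausalLM, RBLNGemma3ForCausalLMConfig

rbln_config = RBLNGemma3ForCausalLMConfig(
    prefill_chunk_size=256,        # illustrative; must equal image_prefill_chunk_size
    image_prefill_chunk_size=256,
)
# _update_rbln_config now forces phases to ["prefill", "image_prefill", "decode"];
# the base class then compiles one model per phase plus one decoder per batch size.
model = RBLNGemma3ForCausalLM.from_pretrained(
    "google/gemma-3-4b-it",  # illustrative checkpoint id
    export=True,
    rbln_config=rbln_config,
)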