optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (108)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +230 -67
  4. optimum/rbln/diffusers/models/controlnet.py +2 -2
  5. optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
  6. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
  7. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
  8. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
  12. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  13. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
  14. optimum/rbln/modeling_base.py +11 -10
  15. optimum/rbln/ops/__init__.py +1 -0
  16. optimum/rbln/ops/attn.py +10 -0
  17. optimum/rbln/ops/flash_attn.py +8 -0
  18. optimum/rbln/ops/moe.py +180 -0
  19. optimum/rbln/ops/sliding_window_attn.py +9 -0
  20. optimum/rbln/transformers/__init__.py +44 -0
  21. optimum/rbln/transformers/modeling_attention_utils.py +124 -222
  22. optimum/rbln/transformers/modeling_outputs.py +25 -0
  23. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  24. optimum/rbln/transformers/models/__init__.py +38 -0
  25. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  27. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
  28. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
  29. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  30. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  31. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  32. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
  33. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  34. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  35. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
  36. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  37. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
  38. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  39. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
  40. optimum/rbln/transformers/models/detr/__init__.py +23 -0
  41. optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
  42. optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
  43. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  44. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  45. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  46. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  47. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  48. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  49. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
  50. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  51. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  53. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  54. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
  55. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  56. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
  57. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  58. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  59. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  60. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  61. optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
  62. optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
  63. optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
  64. optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
  65. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  66. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  67. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  68. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  69. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  70. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  71. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  76. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  77. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  78. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  79. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  80. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  81. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
  82. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  83. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  85. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  86. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  87. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  88. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  89. optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
  90. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  91. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  92. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  94. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
  96. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  97. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  98. optimum/rbln/utils/deprecation.py +78 -1
  99. optimum/rbln/utils/hub.py +93 -2
  100. optimum/rbln/utils/import_utils.py +16 -1
  101. optimum/rbln/utils/runtime_utils.py +12 -8
  102. optimum/rbln/utils/submodule.py +24 -0
  103. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
  104. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
  105. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  106. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
  107. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
  108. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -177,7 +177,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         page_table_manager: RBLNPageTableManager,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        config: "PreTrainedConfig" = None,
+        config: Optional["PreTrainedConfig"] = None,
         logits_last_dim: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
@@ -391,16 +391,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Initialize attention mask for chunked processing
         if self.rbln_config.use_attention_mask:
             if self.rbln_config.use_position_ids:
-                chunked_attention_mask = torch.zeros(
-                    1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
-                )
+                chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=self.rbln_config.dtype)
             else:
                 chunked_attention_mask = torch.zeros(
                     1,
                     1,
                     self.rbln_config.prefill_chunk_size,
                     self.rbln_config.max_seq_len,
-                    dtype=self.rbln_config.torch_dtype,
+                    dtype=self.rbln_config.dtype,
                 )
         else:
             chunked_attention_mask = None
@@ -467,7 +465,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
             logits_last_dim,
         )
-        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.dtype)

         if self.rbln_config.logits_to_keep == 1:
             for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
@@ -486,7 +484,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 self.config.hidden_size,
             )
             output_hidden_states = [
-                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
                 for _ in range(self.config.num_hidden_layers + 1)
             ]

optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -26,15 +26,16 @@ from transformers.modeling_utils import no_init_weights
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import is_compiler_supports_buffer_resize
 from ...modeling_attention_utils import (
     RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
     validate_sliding_window,
 )
-from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
 from ...utils.rbln_quantization import get_quantized_model
-from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_decoderonly import KVCacheMeta, RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import DecoderOnlyWrapper
 from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
 from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
@@ -103,6 +104,11 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "rbln_config": self.rbln_config,
             "config": self.config,
         }
+
+        if self.rbln_config.use_image_prefill:
+            # TODO(sdk-gen): Implement and combine prefill and image prefill into a single runtime.
+            raise NotImplementedError(f"Image prefill at {self.__class__.__name__} is not supported yet.")
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
@@ -230,7 +236,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         quantization=None,
         phase: str = "prefill",
-    ):
+    ) -> rebel.RBLNCompiledModel:
         try:
             wrapped_model.phase = phase
             if quantization:
@@ -252,21 +258,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             quantization.maybe_reset_quantization_env()

     @classmethod
-    def _get_compile_context(
-        cls,
-        compile_config: RBLNCompileConfig,
-        example_inputs: List[torch.Tensor],
-    ):
+    def _get_compile_context(cls, compile_config: RBLNCompileConfig, example_inputs: List[torch.Tensor]):
         context = CompileContext(use_weight_sharing=True)

         # Mark static tensors (self kv states)
         static_tensors = {}
-        idx = 0
         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
             if "past_key_values" in name:
                 static_tensors[name] = tensor
-                context.mark_static_address(tensor, f"kv_cache_{idx}")
-                idx += 1
+                context.mark_static_address(tensor, name)

         return context, static_tensors

@@ -281,7 +281,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)

-        compiled_models = {}
+        compiled_models: dict[str, rebel.RBLNCompiledModel] = {}
         compiled_models["prefill"] = cls._compile_model(
             wrapped_model,
             prefill_compile_config,
@@ -292,9 +292,27 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             phase="prefill",
         )

+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_config.compile_cfgs[rbln_config.image_prefill_runtime_idx]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            compiled_image_prefill = cls._compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config,
+                rbln_config.quantization,
+                phase="image_prefill",
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill
+
         if rbln_config.can_generate:
             wrapped_model.phase = "decode"
-            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+            for batch_size, dec_compile_config in zip(
+                rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[rbln_config.decoder_runtime_idx :]
+            ):
                 dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
                 compiled_decoder = cls._compile_model(
                     wrapped_model,
@@ -307,14 +325,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                 )
                 compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

-        # check if the memory is enough to have additional blocks
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        if rbln_config.kvcache_num_blocks < required_num_blocks:
-            cls.maybe_suggest_kvcache_num_blocks(
-                compiled_models=compiled_models,
-                model_config=model.config,
-                rbln_config=rbln_config,
-            )
+        if rbln_config.is_auto_num_blocks:
+            if not is_compiler_supports_buffer_resize():
+                raise RuntimeError("`kvcache_num_blocks` must be set.")
+            cls.set_kvcache_num_blocks_after_compilation(compiled_models, rbln_config)

         return compiled_models

@@ -330,8 +344,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return model

     @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return use_local_attention
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None):
+        return is_prefill and (use_local_attention or logits_to_keep == 1)

     @classmethod
     def get_input_info(
@@ -350,7 +364,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

         input_info = []
         if rbln_config.use_inputs_embeds:
-            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.dtype))
         else:
             input_info.append(("input_ids", [batch_size, query_length], "int64"))

@@ -364,15 +378,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_local_attention:
             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill, rbln_config.logits_to_keep):
             input_info.append(("query_position", [], "int16"))

         if rbln_config.use_attention_mask:
             if rbln_config.use_position_ids:
-                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.dtype))
             else:
                 input_info.append(
-                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.dtype)
                 )

         if rbln_config.use_position_ids:
@@ -381,29 +395,36 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

         if rbln_config.use_lora:
             input_info.append(("lora_int_ids", [batch_size], "int32"))

-        kvcache_dtype = rbln_config.torch_dtype
-        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
-            kvcache_dtype = "float8_e4m3fn"
+        if len(rbln_config.kvcache_metas) > 0:
+            # Meta is already set, use it
+            input_info.extend(
+                [
+                    (kvcache_meta.name, kvcache_meta.compile_shape, kvcache_meta.dtype)
+                    for kvcache_meta in rbln_config.kvcache_metas
+                ]
+            )

-        global_kvcache_shape = [
-            rbln_config.kvcache_num_blocks,
-            num_key_value_heads,
-            rbln_config.kvcache_block_size,
-            head_dim,
-        ]
-        local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape
-                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
-                    else global_kvcache_shape,
-                    kvcache_dtype,
+        else:
+            kvcache_dtype = rbln_config.dtype
+            if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+                kvcache_dtype = "float8_e4m3fn"
+
+            kvcache_metas = []
+            for i in range(num_hidden_layers * 2):
+                layer_idx = i // 2
+                name = f"past_key_values_{i}"
+                kvcache_meta = KVCacheMeta.make(
+                    name,
+                    layer_idx,
+                    num_key_value_heads,
+                    head_dim,
+                    RBLNCompileConfig.normalize_dtype(kvcache_dtype),
+                    rbln_config,
                 )
-                for i in range(num_hidden_layers * 2)
-            ]
-        )
+                kvcache_metas.append(kvcache_meta)
+                input_info.append((name, kvcache_meta.compile_shape, kvcache_meta.dtype))
+
+            rbln_config.kvcache_metas.extend(kvcache_metas)

         return input_info

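The removed branch makes the two cache layouts explicit, and KVCacheMeta now carries the same shape and dtype information per tensor: paged global caches shaped [num_blocks, num_kv_heads, block_size, head_dim] and per-batch sliding-window caches shaped [batch_size, num_kv_heads, sliding_window, head_dim], two tensors (key and value) per layer. A standalone sketch of that shape selection, with made-up sizes:

    # Restatement of the removed shape logic; all sizes are illustrative.
    num_key_value_heads, head_dim = 8, 128
    kvcache_num_blocks, kvcache_block_size = 17, 1024  # paged (global) layers
    batch_size, sliding_window = 2, 512                # sliding-window (local) layers
    sliding_window_layers = {1, 3}                     # layer indices using local attention

    def kvcache_shape(i: int) -> list:
        # Two cache tensors per layer (key and value), so the layer index is i // 2.
        if (i // 2) in sliding_window_layers:
            return [batch_size, num_key_value_heads, sliding_window, head_dim]
        return [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]

    print(kvcache_shape(0))  # past_key_values_0 -> [17, 8, 1024, 128] (global, layer 0)
    print(kvcache_shape(2))  # past_key_values_2 -> [2, 8, 512, 128] (local, layer 1)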
@@ -475,51 +496,39 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             max_seq_len=rbln_config.max_seq_len,
         )

-        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-
-        # Update kvcache_num_blocks based on the attention implementation.
+        # Validate kvcache_num_blocks based on the number of full blocks required.
+        # Eager mode restriction:
+        # - num_blocks must be at least equal to the batch size
+        # Flash attention restriction:
+        # - num_blocks must be at least equal to (max_seq_len // kvcache_block_size) + 1
+        # - num_blocks must be no greater than the number of full blocks.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
-                model=model, model_config=model_config, rbln_config=rbln_config
-            )
+            if rbln_config.is_auto_num_blocks:
+                # Do nothing
+                pass

-            if rbln_config.kvcache_num_blocks is None:
-                if estimated_max_num_blocks < num_full_blocks:
-                    # lower bound of the number of blocks for flash attention.
-                    min_blocks_for_flash = min(
-                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+            else:
+                if rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
+                    logger.warning(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                        f" than the required number of blocks ({rbln_config.num_full_blocks})."
+                        "This can cause a failure during model compilation."
+                    )
+                elif rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
+                    raise ValueError(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is less"
+                        f" than the minimum number of blocks ({rbln_config.num_min_blocks})."
                     )
-                    if min_blocks_for_flash > estimated_max_num_blocks:
-                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
-                        # Even if it's larger than the estimated maximum number of blocks.
-                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
-                    else:
-                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
-
-                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
-                        raise RuntimeError(
-                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
-                            "Ensure the number of blocks is at least equal to the batch size."
-                        )
-                else:
-                    rbln_config.kvcache_num_blocks = num_full_blocks
-            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
-                logger.warning(
-                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
-                    "This can cause a failure during model compilation."
-                )
         else:
-            if rbln_config.kvcache_num_blocks is None:
-                rbln_config.kvcache_num_blocks = num_full_blocks
-            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+            if rbln_config.is_auto_num_blocks:
+                # Eager attention should use fixed number of blocks.
+                rbln_config.kvcache_num_blocks = rbln_config.num_full_blocks
+            elif rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
                 logger.warning(
                     f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the required number of blocks ({num_full_blocks})."
+                    f" than the required number of blocks ({rbln_config.num_full_blocks})."
                     "This can cause a failure during model compilation."
                 )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

         return rbln_config

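The bounds named in the new comments follow from the paging arithmetic visible in this diff: the old code computed num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size, and the flash-attention minimum is one full sequence of blocks plus one spare. A worked example with assumed sizes (num_full_blocks and num_min_blocks are now rbln_config properties; the formulas below restate them for illustration only):

    # Worked example of the kvcache_num_blocks bounds; values are illustrative.
    max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 2

    num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 8 * 2 = 16
    min_blocks_flash = max_seq_len // kvcache_block_size + 1            # 8 + 1 = 9

    # With flash attention, a user-set kvcache_num_blocks below 9 raises a
    # ValueError above, while a value above 16 only triggers a warning.
    assert min_blocks_flash <= 12 <= num_full_blocks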
@@ -562,6 +571,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         compile_cfgs = [prefill_compile_config]

+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+            image_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=image_prefill_input_info
+            )
+            compile_cfgs.append(image_prefill_compile_config)
+
         if rbln_config.can_generate:
             for batch_size in rbln_config.decoder_batch_sizes:
                 dec_input_info = cls.get_input_info(
@@ -583,36 +608,21 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = ["prefill"]
-        if rbln_config.can_generate:
-            expected_model_names.extend(
-                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
-            )
+        expected_model_names = rbln_config.expected_compiled_model_names
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

         ret_val = [
             rebel.Runtime(
-                compiled_models[0],
+                compiled_models[i],
                 tensor_type="pt",
-                device=rbln_config.device_map["prefill"],
+                device=rbln_config.device_map[model_name],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
             )
+            for i, model_name in enumerate(expected_model_names)
         ]
-        if rbln_config.can_generate:
-            ret_val.extend(
-                [
-                    rebel.Runtime(
-                        compiled_models[i + 1],
-                        tensor_type="pt",
-                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                        activate_profiler=rbln_config.activate_profiler,
-                        timeout=rbln_config.timeout,
-                    )
-                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-                ]
-            )
         return ret_val

     def forward(
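Runtime construction is now driven entirely by rbln_config.expected_compiled_model_names, so runtimes are created in the order the compile configs were registered. Based on the compile-config construction earlier in this diff, the list plausibly takes this shape (a sketch only; the actual property lives in configuration_decoderonly.py):

    # Hypothetical reconstruction of expected_compiled_model_names, for illustration.
    def expected_compiled_model_names(use_image_prefill, can_generate, decoder_batch_sizes):
        names = ["prefill"]
        if use_image_prefill:
            names.append("image_prefill")
        if can_generate:
            names.extend(f"decoder_batch_{bs}" for bs in decoder_batch_sizes)
        return names

    print(expected_compiled_model_names(True, True, [1, 4]))
    # ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']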
@@ -643,15 +653,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             raise ValueError(
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )
-
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

         all_last_hidden_states = []
         all_hidden_states = (
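Both call sites now delegate to a shared helper in modeling_outputs.py. Given the inline logic it replaces, the helper is plausibly factored as follows (a sketch; the actual implementation is the new `_validate_output_hidden_states`):

    # Plausible factoring of the removed inline validation, for reference.
    def _validate_output_hidden_states(output_hidden_states, rbln_config):
        # Fall back to the compile-time setting, then reject any mismatch:
        # a model compiled without hidden-state outputs cannot produce them.
        if output_hidden_states is None:
            output_hidden_states = rbln_config.output_hidden_states
        if output_hidden_states != rbln_config.output_hidden_states:
            raise ValueError(
                f"Variable output_hidden_states {output_hidden_states} is not equal to "
                f"rbln_config.output_hidden_states {rbln_config.output_hidden_states}. "
                "Please compile again with the correct argument."
            )
        return output_hidden_states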
@@ -660,7 +662,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                 self.rbln_config.batch_size,
                 inputs.shape[1],
                 self.config.hidden_size,
-                dtype=self.rbln_config.torch_dtype,
+                dtype=self.rbln_config.dtype,
             )
             for _ in range(self.config.num_hidden_layers + 1)
         )
@@ -700,6 +702,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
+
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures and causal language modeling tasks.

@@ -716,10 +719,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     def logits_last_dim(self):
         return self.config.vocab_size

-    @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return is_prefill
-
     def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
         if isinstance(lora_int_ids, int):
             lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
@@ -803,14 +802,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
         )
         padded_cache_lengths = torch.zeros_like(generate_idx)

-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

         # Prefill
         if cache_position is None:
@@ -829,7 +821,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener

         all_hidden_states = (
             tuple(
-                torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+                torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.dtype)
                 for _ in range(self.config.num_hidden_layers + 1)
             )
             if self.rbln_config.output_hidden_states
optimum/rbln/transformers/models/detr/__init__.py

@@ -0,0 +1,23 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .configuration_detr import RBLNDetrForObjectDetectionConfig
+from .modeling_detr import RBLNDetrForObjectDetection
+
+
+__all__ = [
+    "RBLNDetrForObjectDetectionConfig",
+    "RBLNDetrForObjectDetection",
+]
optimum/rbln/transformers/models/detr/configuration_detr.py

@@ -0,0 +1,38 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+class RBLNDetrForObjectDetectionConfig(RBLNModelForImageClassificationConfig):
+    """
+    Configuration class for RBLNDetrForObjectDetection.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized DETR models for object detection tasks.
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
+                Can be an integer for square images or a tuple (height, width).
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
optimum/rbln/transformers/models/detr/modeling_detr.py

@@ -0,0 +1,53 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING, Tuple, Union
+
+import torch
+from transformers.models.detr.modeling_detr import DetrObjectDetectionOutput
+
+from ...modeling_generic import RBLNModelForImageClassification
+
+
+if TYPE_CHECKING:
+    pass
+
+
+class RBLNDetrForObjectDetection(RBLNModelForImageClassification):
+    """
+    RBLN optimized DETR model for object detection tasks.
+
+    This class provides hardware-accelerated inference for DETR models
+    on RBLN devices, supporting object detection with detection heads
+    designed for object detection tasks.
+    """
+
+    def forward(
+        self, pixel_values: torch.Tensor, return_dict: bool = None, **kwargs
+    ) -> Union[Tuple, DetrObjectDetectionOutput]:
+        """
+        Forward pass for the RBLN-optimized DETR model for object detection.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
+            return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DetrObjectDetectionOutput object.
+        """
+        output = self.model[0](pixel_values=pixel_values, **kwargs)
+        return DetrObjectDetectionOutput(
+            logits=output[0], pred_boxes=output[1], last_hidden_state=output[2], encoder_last_hidden_state=output[3]
+        )
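For orientation, the new class follows the same export flow as other optimum-rbln models. A hypothetical usage sketch (not part of this diff; the checkpoint name, the root-level re-export, and the input resolution are assumptions, and the input shape must match the image size the model was compiled with):

    # Hypothetical usage of the new DETR support; names and shapes are illustrative.
    import torch
    from optimum.rbln import RBLNDetrForObjectDetection

    # export=True compiles the Hugging Face checkpoint for RBLN NPUs.
    model = RBLNDetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", export=True)

    pixel_values = torch.randn(1, 3, 800, 800)  # must match the compiled image size
    outputs = model(pixel_values=pixel_values)
    print(outputs.logits.shape, outputs.pred_boxes.shape)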
optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -18,9 +18,6 @@ import torch.nn as nn

 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )

@@ -42,36 +39,3 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):

     def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
         return causal_lm.transformer
-
-    def get_rbln_attn_class(self):
-        return ExaoneAttention
-
-    def get_rbln_layer_class(self):
-        return ExaoneLayer
-
-    def get_rbln_model_class(self):
-        return ExaoneModel
-
-
-class ExaoneModel(DecoderOnlyModel):
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-
-class ExaoneLayer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
-class ExaoneAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -24,4 +24,4 @@ class GemmaWrapper(DecoderOnlyWrapper):
 class GemmaModel(DecoderOnlyModel):
     @property
     def hidden_multiplier(self):
-        return self._original_mod.config.hidden_size**0.5
+        return self.config.hidden_size**0.5
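`hidden_multiplier` is the sqrt(hidden_size) factor Gemma applies to its input embeddings; the change only reroutes the config lookup through the wrapper itself. Numerically (an illustrative size, not tied to a specific checkpoint):

    # Gemma scales input embeddings by sqrt(hidden_size).
    hidden_size = 2048
    print(hidden_size**0.5)  # 45.254833995939045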
optimum/rbln/transformers/models/gemma2/__init__.py

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma2 import RBLNGemma2ForCausalLMConfig, RBLNGemma2ModelConfig
+from .modeling_gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2Model