optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -26,15 +26,16 @@ from transformers.modeling_utils import no_init_weights
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import is_compiler_supports_buffer_resize
 from ...modeling_attention_utils import (
     RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
     validate_sliding_window,
 )
-from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
 from ...utils.rbln_quantization import get_quantized_model
-from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_decoderonly import KVCacheMeta, RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import DecoderOnlyWrapper
 from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
 from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
@@ -230,7 +231,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         quantization=None,
         phase: str = "prefill",
-    ):
+    ) -> rebel.RBLNCompiledModel:
         try:
             wrapped_model.phase = phase
             if quantization:
@@ -252,21 +253,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             quantization.maybe_reset_quantization_env()

     @classmethod
-    def _get_compile_context(
-        cls,
-        compile_config: RBLNCompileConfig,
-        example_inputs: List[torch.Tensor],
-    ):
+    def _get_compile_context(cls, compile_config: RBLNCompileConfig, example_inputs: List[torch.Tensor]):
         context = CompileContext(use_weight_sharing=True)

         # Mark static tensors (self kv states)
         static_tensors = {}
-        idx = 0
         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
             if "past_key_values" in name:
                 static_tensors[name] = tensor
-                context.mark_static_address(tensor, f"kv_cache_{idx}")
-                idx += 1
+                context.mark_static_address(tensor, name)

         return context, static_tensors

@@ -281,7 +276,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)

-        compiled_models = {}
+        compiled_models: dict[str, rebel.RBLNCompiledModel] = {}
         compiled_models["prefill"] = cls._compile_model(
             wrapped_model,
             prefill_compile_config,
@@ -307,14 +302,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             )
             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

-        # check if the memory is enough to have additional blocks
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        if rbln_config.kvcache_num_blocks < required_num_blocks:
-            cls.maybe_suggest_kvcache_num_blocks(
-                compiled_models=compiled_models,
-                model_config=model.config,
-                rbln_config=rbln_config,
-            )
+        if rbln_config.is_auto_num_blocks:
+            if not is_compiler_supports_buffer_resize():
+                raise RuntimeError("`kvcache_num_blocks` must be set.")
+            cls.set_kvcache_num_blocks_after_compilation(compiled_models, rbln_config)

         return compiled_models

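Note on the new auto mode: when `kvcache_num_blocks` is left unset, the block count is now fixed up after compilation via `set_kvcache_num_blocks_after_compilation`, which presupposes a compiler with buffer-resize support; older compilers must be given an explicit count. A usage sketch under that assumption (the checkpoint name and sizes are placeholders, and the `rbln_*` keyword form follows the pattern shown in the docstrings further down):

```python
from optimum.rbln import RBLNLlamaForCausalLM  # any RBLN decoder-only model class

# Explicit block count for compilers without buffer resize.
# With flash_attn, the validation below requires at least
# max_seq_len // kvcache_block_size + 1 blocks (8192 // 4096 + 1 = 3 here).
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder checkpoint
    export=True,
    rbln_max_seq_len=8192,
    rbln_kvcache_block_size=4096,
    rbln_kvcache_num_blocks=3,
)
```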
@@ -330,8 +321,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return model

     @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return use_local_attention
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None):
+        return is_prefill and (use_local_attention or logits_to_keep == 1)

     @classmethod
     def get_input_info(
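The new predicate is easy to misread in diff form, and it also explains why the `use_query_position` override on `RBLNDecoderOnlyModelForCausalLM` is deleted later in this diff: the base method now covers both cases. Restated in plain Python, mirroring the two lines added above:

```python
def use_query_position(use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None) -> bool:
    return is_prefill and (use_local_attention or logits_to_keep == 1)

assert not use_query_position(True, is_prefill=False)         # decode phase: never needed
assert use_query_position(True, is_prefill=True)              # prefill with local (sliding-window) attention
assert use_query_position(False, True, logits_to_keep=1)      # prefill keeping only the last token's logits
assert not use_query_position(False, True, logits_to_keep=0)  # prefill, global attention, full logits
```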
@@ -350,7 +341,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):

         input_info = []
         if rbln_config.use_inputs_embeds:
-            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.dtype))
         else:
             input_info.append(("input_ids", [batch_size, query_length], "int64"))

@@ -364,15 +355,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_local_attention:
             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill, rbln_config.logits_to_keep):
             input_info.append(("query_position", [], "int16"))

         if rbln_config.use_attention_mask:
             if rbln_config.use_position_ids:
-                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.dtype))
             else:
                 input_info.append(
-                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.dtype)
                 )

         if rbln_config.use_position_ids:
@@ -381,29 +372,36 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_lora:
             input_info.append(("lora_int_ids", [batch_size], "int32"))

-        kvcache_dtype = rbln_config.torch_dtype
-        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
-            kvcache_dtype = "float8_e4m3fn"
+        if len(rbln_config.kvcache_metas) > 0:
+            # Meta is already set, use it
+            input_info.extend(
+                [
+                    (kvcache_meta.name, kvcache_meta.compile_shape, kvcache_meta.dtype)
+                    for kvcache_meta in rbln_config.kvcache_metas
+                ]
+            )

-        global_kvcache_shape = [
-            rbln_config.kvcache_num_blocks,
-            num_key_value_heads,
-            rbln_config.kvcache_block_size,
-            head_dim,
-        ]
-        local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape
-                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
-                    else global_kvcache_shape,
-                    kvcache_dtype,
+        else:
+            kvcache_dtype = rbln_config.dtype
+            if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+                kvcache_dtype = "float8_e4m3fn"
+
+            kvcache_metas = []
+            for i in range(num_hidden_layers * 2):
+                layer_idx = i // 2
+                name = f"past_key_values_{i}"
+                kvcache_meta = KVCacheMeta.make(
+                    name,
+                    layer_idx,
+                    num_key_value_heads,
+                    head_dim,
+                    RBLNCompileConfig.normalize_dtype(kvcache_dtype),
+                    rbln_config,
                 )
-                for i in range(num_hidden_layers * 2)
-            ]
-        )
+                kvcache_metas.append(kvcache_meta)
+                input_info.append((name, kvcache_meta.compile_shape, kvcache_meta.dtype))
+
+            rbln_config.kvcache_metas.extend(kvcache_metas)

         return input_info

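`KVCacheMeta` is defined in `configuration_decoderonly.py` and is not part of this diff, but the deleted branch pins down what `KVCacheMeta.make` must decide per layer. A sketch of that shape selection, inferred from the removed code (the function name is illustrative):

```python
from typing import List

def kvcache_compile_shape(layer_idx: int, num_key_value_heads: int, head_dim: int, rbln_config) -> List[int]:
    is_local = rbln_config.sliding_window is not None and layer_idx in rbln_config.sliding_window_layers
    if is_local:
        # Sliding-window layers keep a per-batch local cache.
        return [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
    # All other layers share a paged global cache.
    return [rbln_config.kvcache_num_blocks, num_key_value_heads, rbln_config.kvcache_block_size, head_dim]
```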
@@ -475,51 +473,39 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             max_seq_len=rbln_config.max_seq_len,
         )

-        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-
-        # Update kvcache_num_blocks based on the attention implementation.
+        # Validate kvcache_num_blocks based on the number of full blocks required.
+        # Eager mode restriction:
+        #   - num_blocks must be at least equal to the batch size
+        # Flash attention restriction:
+        #   - num_blocks must be at least equal to (max_seq_len // kvcache_block_size) + 1
+        #   - num_blocks must be no greater than the number of full blocks.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
-                model=model, model_config=model_config, rbln_config=rbln_config
-            )
+            if rbln_config.is_auto_num_blocks:
+                # Do nothing
+                pass

-            if rbln_config.kvcache_num_blocks is None:
-                if estimated_max_num_blocks < num_full_blocks:
-                    # lower bound of the number of blocks for flash attention.
-                    min_blocks_for_flash = min(
-                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+            else:
+                if rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
+                    logger.warning(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                        f" than the required number of blocks ({rbln_config.num_full_blocks})."
+                        "This can cause a failure during model compilation."
+                    )
+                elif rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
+                    raise ValueError(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is less"
+                        f" than the minimum number of blocks ({rbln_config.num_min_blocks})."
                     )
-                    if min_blocks_for_flash > estimated_max_num_blocks:
-                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
-                        # Even if it's larger than the estimated maximum number of blocks.
-                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
-                    else:
-                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
-
-                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
-                        raise RuntimeError(
-                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
-                            "Ensure the number of blocks is at least equal to the batch size."
-                        )
-                else:
-                    rbln_config.kvcache_num_blocks = num_full_blocks
-            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
-                logger.warning(
-                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
-                    "This can cause a failure during model compilation."
-                )
         else:
-            if rbln_config.kvcache_num_blocks is None:
-                rbln_config.kvcache_num_blocks = num_full_blocks
-            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+            if rbln_config.is_auto_num_blocks:
+                # Eager attention should use fixed number of blocks.
+                rbln_config.kvcache_num_blocks = rbln_config.num_full_blocks
+            elif rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
                 logger.warning(
                     f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the required number of blocks ({num_full_blocks})."
+                    f" than the required number of blocks ({rbln_config.num_full_blocks})."
                     "This can cause a failure during model compilation."
                 )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

         return rbln_config

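`num_full_blocks` and `num_min_blocks` are new config properties whose definitions are not part of this diff; the comments above together with the removed code imply the following formulas (a sketch, not the config source):

```python
def num_full_blocks(cfg) -> int:
    # Enough blocks for every sequence in the batch to reach max_seq_len.
    return (cfg.max_seq_len // cfg.kvcache_block_size) * cfg.batch_size

def num_min_blocks(cfg) -> int:
    if cfg.attn_impl == "flash_attn":
        # Flash attention needs one block beyond a single full-length sequence.
        return cfg.max_seq_len // cfg.kvcache_block_size + 1
    # Eager attention needs at least one block per batch entry.
    return cfg.batch_size
```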
@@ -643,15 +629,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             raise ValueError(
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )
-
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

         all_last_hidden_states = []
         all_hidden_states = (
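`_validate_output_hidden_states` replaces the same inline check at two call sites (here and in the causal-LM `forward` below). Its definition lives in `modeling_outputs.py` and is not shown in this diff; a sketch consistent with the code it replaces:

```python
def _validate_output_hidden_states(output_hidden_states, rbln_config):
    # Fall back to the compile-time setting when the caller passes None.
    if output_hidden_states is None:
        return rbln_config.output_hidden_states
    # The compiled graph is fixed, so a mismatching request cannot be honored.
    if output_hidden_states != rbln_config.output_hidden_states:
        raise ValueError(
            f"Variable output_hidden_states {output_hidden_states} is not equal to "
            f"rbln_config.output_hidden_states {rbln_config.output_hidden_states}. "
            "Please compile again with the correct argument."
        )
    return output_hidden_states
```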
@@ -660,7 +638,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                     self.rbln_config.batch_size,
                     inputs.shape[1],
                     self.config.hidden_size,
-                    dtype=self.rbln_config.torch_dtype,
+                    dtype=self.rbln_config.dtype,
                 )
                 for _ in range(self.config.num_hidden_layers + 1)
             )
@@ -700,6 +678,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
+
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures and causal language modeling tasks.

@@ -716,10 +695,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     def logits_last_dim(self):
         return self.config.vocab_size

-    @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return is_prefill
-
     def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
         if isinstance(lora_int_ids, int):
             lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
@@ -803,14 +778,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
        )
        padded_cache_lengths = torch.zeros_like(generate_idx)

-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
-        )
-        if output_hidden_states != self.rbln_config.output_hidden_states:
-            raise ValueError(
-                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
-                f"Please compile again with the correct argument."
-            )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

        # Prefill
        if cache_position is None:
@@ -829,7 +797,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener

             all_hidden_states = (
                 tuple(
-                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.dtype)
                     for _ in range(self.config.num_hidden_layers + 1)
                 )
                 if self.rbln_config.output_hidden_states
optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -18,9 +18,6 @@ import torch.nn as nn

 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )

@@ -42,36 +39,3 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):

     def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
         return causal_lm.transformer
-
-    def get_rbln_attn_class(self):
-        return ExaoneAttention
-
-    def get_rbln_layer_class(self):
-        return ExaoneLayer
-
-    def get_rbln_model_class(self):
-        return ExaoneModel
-
-
-class ExaoneModel(DecoderOnlyModel):
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-
-class ExaoneLayer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
-class ExaoneAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -24,4 +24,4 @@ class GemmaWrapper(DecoderOnlyWrapper):
 class GemmaModel(DecoderOnlyModel):
     @property
     def hidden_multiplier(self):
-        return self._original_mod.config.hidden_size**0.5
+        return self.config.hidden_size**0.5
optimum/rbln/transformers/models/gemma2/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma2 import RBLNGemma2ForCausalLMConfig, RBLNGemma2ModelConfig
+from .modeling_gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2Model
optimum/rbln/transformers/models/gemma2/configuration_gemma2.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGemma2ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig
+    # Create a configuration object
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNGemma2ModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Gemma2 models.
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
optimum/rbln/transformers/models/gemma2/gemma2_architecture.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyModel
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class Gemma2Wrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Gemma2DecoderLayer
+
+    def get_rbln_attn_class(self):
+        return Gemma2Attention
+
+    def get_rbln_model_class(self):
+        return Gemma2Model
+
+
+class Gemma2DecoderLayer(DecoderOnlyLayer):
+    _PRE_FF_LAYERNORM_ATTRS = ["pre_feedforward_layernorm"]
+    _POST_FF_LAYERNORM_ATTRS = ["post_feedforward_layernorm"]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
+        past_key_values: Tuple[Tuple[torch.Tensor]],
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            seq_positions=seq_positions,
+            past_key_values=past_key_values,
+            cos=cos,
+            sin=sin,
+            block_tables=block_tables,
+            lora_int_id=lora_int_id,
+        )
+        hidden_states = self.get_post_attention_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
+        hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Gemma2Attention(DecoderOnlyAttention):
+    def get_attn_scale(self, self_attn):
+        return self_attn.config.query_pre_attn_scalar**-0.5
+
+
+class Gemma2Model(DecoderOnlyModel):
+    @property
+    def hidden_multiplier(self):
+        return self.config.hidden_size**0.5
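This wrapper captures Gemma2's deviations from the generic decoder layer: sandwich normalization around both the attention and MLP blocks (the overridden `forward` above), a query scale taken from `query_pre_attn_scalar` instead of `head_dim`, and input embeddings multiplied by `sqrt(hidden_size)`. A toy check of the two scalars (values are illustrative, not from any real checkpoint):

```python
# Illustrative values only; real Gemma2 configs may differ.
hidden_size = 2304
head_dim = 256
query_pre_attn_scalar = 224

attn_scale = query_pre_attn_scalar**-0.5  # what Gemma2Attention.get_attn_scale returns
base_scale = head_dim**-0.5               # what the base DecoderOnlyAttention would use
embed_mult = hidden_size**0.5             # Gemma2Model.hidden_multiplier

print(f"attn scale {attn_scale:.5f} (base {base_scale:.5f}), embeddings x {embed_mult:.2f}")
```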
optimum/rbln/transformers/models/gemma2/modeling_gemma2.py
@@ -0,0 +1,101 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ....utils import logging
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+)
+from .gemma2_architecture import Gemma2Wrapper
+
+
+logger = logging.get_logger(__name__)
+
+
+class RBLNGemma2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Gemma2 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2ForCausalLM model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGemma2ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 8192,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNGemma2ForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGemma2ForCausalLMConfig
+    config = RBLNGemma2ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    model = RBLNGemma2ForCausalLM.from_pretrained(
+        "google/gemma-2-9b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper
+
+
+class RBLNGemma2Model(RBLNDecoderOnlyModel):
+    """
+    The Gemma2 Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Gemma2Model model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Gemma2Model model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemma2ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemma2ModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemma2ModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = Gemma2Wrapper