optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
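
The headline change in this release is the replacement of the dict-backed `RBLNConfig` (`modeling_config.py`, deleted below with -310 lines) by typed, per-model configuration classes built on the new `configuration_utils.py`: every `rbln_config.model_cfg["key"]` lookup in the hunks that follow becomes a plain attribute access such as `rbln_config.max_seq_len`. A minimal before/after sketch of what this looks like from user code, assuming the new classes are re-exported from the package root as the expanded `optimum/rbln/__init__.py` entry suggests; the `from_pretrained` keywords follow the usual optimum-rbln conventions and are not themselves part of this diff:

    # Hedged sketch: class names come from this diff, constructor kwargs are assumed.
    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

    rbln_config = RBLNLlamaForCausalLMConfig(
        max_seq_len=4096,  # previously rbln_config.model_cfg["max_seq_len"]
        batch_size=2,      # previously rbln_config.model_cfg["batch_size"]
    )
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf", export=True, rbln_config=rbln_config
    )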
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -26,13 +26,15 @@ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, Pre
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...utils.rbln_quantization import QuantizationManager
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
+    set_default_values,
     validate_attention_method,
 )
@@ -356,17 +358,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     _use_rotary_emb = True

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
-        # FIXME get kvcache_num_blocks from compiled results.
-        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
-        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name

-        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+        if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         with no_init_weights():
@@ -380,40 +374,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             self.embed_tokens = None

         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-        dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
         block_tables = torch.zeros(
-            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
         ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))

         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.prefill_chunk_size,
-            max_seq_len=self.max_seq_len,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
+            max_seq_len=self.rbln_config.max_seq_len,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="decode",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )

     @classmethod
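
The shared-resource shapes above follow directly from the config: each sequence gets `max_seq_len // kvcache_block_size` block-table slots, and the free pool is seeded with every block index. A worked restatement with illustrative numbers (not values taken from the diff):

    import torch
    from collections import deque

    # Illustrative values standing in for the RBLNDecoderOnlyModelForCausalLMConfig fields.
    batch_size, max_seq_len, kvcache_block_size = 2, 4096, 1024
    kvcache_num_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 8

    # One row of slots per sequence; -1 marks "no block assigned yet".
    block_tables = torch.full(
        (batch_size, max_seq_len // kvcache_block_size), -1, dtype=torch.int16
    )  # shape [2, 4]
    free_block_pool = deque(range(kvcache_num_blocks))  # block ids 0..7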
@@ -422,13 +420,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
         store the torch tensor, weight, etc. in this function.
         """
-        if rbln_config.model_cfg["use_inputs_embeds"]:
+        if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -493,33 +491,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return val

     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
-        logger.debug("Loading the LLM model to the CPU.")  # TODO(jongho): Remove.
-
-        rbln_kwargs = kwargs.get("rbln_kwargs", {})
-        rbln_quantization = rbln_kwargs.get("quantization", None)
-        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
+    ) -> "PreTrainedModel":
+        if (
+            rbln_config is not None
+            and "format" in rbln_config.quantization
+            and rbln_config.quantization["format"] == "rbln"
+        ):
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)

-        logger.debug("Loaded the LLM model to the CPU.")
         return model

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
-        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
-        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
-        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
-        wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
-
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+        }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
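
Note the shape of the new quantization guard: judging by the `16 if not rbln_config.quantization else 4` expression later in this diff, `quantization` is now a dict that defaults to empty rather than `None`, so membership is checked before the `"format"` lookup. A small sketch of why that ordering is safe (contents beyond the `"format"` key are hypothetical):

    # Hypothetical config values; only the "format" key is attested in this diff.
    quantization = {}  # unset case: empty dict, not None

    # Short-circuit evaluation keeps the key lookup from raising KeyError.
    assert ("format" in quantization and quantization["format"] == "rbln") is False

    quantization = {"format": "rbln"}
    assert "format" in quantization and quantization["format"] == "rbln"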
@@ -541,8 +541,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

-        quantize_config = rbln_config.model_cfg.get("quantization", None)
-
         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
             try:
@@ -567,7 +565,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             finally:
                 torch.nn.functional.linear = original_linear

-        return compile_model(quantize_config=quantize_config)
+        return compile_model(quantize_config=rbln_config.quantization)

     @classmethod
     def get_maximum_num_blocks(
@@ -654,78 +652,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return max_n_blocks

     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
-        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
-        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
-        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
-        rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
-
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
-
-        if rbln_prefill_chunk_size is None:
-            rbln_prefill_chunk_size = 128
-        elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
-            raise ValueError(
-                f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
-            )
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
                 model_config, "n_positions", None
             )
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified.")
-
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )

-        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
-            rbln_attn_impl=rbln_attn_impl,
-            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
-            rbln_kvcache_block_size=rbln_kvcache_block_size,
-            rbln_max_seq_len=rbln_max_seq_len,
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
         )

-        if rbln_kvcache_block_size is None:
-            if rbln_attn_impl == "eager":
-                rbln_kvcache_block_size = rbln_max_seq_len
-            else:
-                rbln_kvcache_block_size = rbln_kvcache_partition_len
+        rbln_config.kvcache_num_blocks = (
+            rbln_config.max_seq_len // rbln_config.kvcache_block_size
+        ) * rbln_config.batch_size

-        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
-        if rbln_attn_impl == "flash_attn":
+        if rbln_config.attn_impl == "flash_attn":
             max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
-                tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
-                kvcache_block_size=rbln_kvcache_block_size,
-                nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=rbln_kwargs["n_model_params"],
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=sum(p.numel() for p in model.parameters()),
             )
-            rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)
+            rbln_config.kvcache_num_blocks = min(rbln_config.kvcache_num_blocks, max_num_blocks)

-            required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
-            if rbln_kvcache_num_blocks < required_blocks:
-                rbln_kvcache_num_blocks = required_blocks
+            required_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if rbln_config.kvcache_num_blocks < required_blocks:
+                rbln_config.kvcache_num_blocks = required_blocks

-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

-        if rbln_kvcache_num_blocks < rbln_batch_size:
+        if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
             raise RuntimeError(
-                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({rbln_config.kvcache_num_blocks}). "
                 "Ensure the number of blocks is at least equal to the batch size."
             )
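
The block-count logic in `_update_rbln_config` reduces to a few lines of arithmetic: allocate one full set of blocks per sequence, optionally clamp to what device memory allows under flash attention, then restore the floor of one spare block beyond a single full sequence. A standalone restatement, with `max_num_blocks` standing in for the result of `get_maximum_num_blocks`:

    from typing import Optional

    def resolve_kvcache_num_blocks(
        max_seq_len: int,
        kvcache_block_size: int,
        batch_size: int,
        attn_impl: str,
        max_num_blocks: Optional[int] = None,  # stand-in for get_maximum_num_blocks()
    ) -> int:
        num_blocks = (max_seq_len // kvcache_block_size) * batch_size
        if attn_impl == "flash_attn":
            num_blocks = min(num_blocks, max_num_blocks)
            # Prefill needs one spare block beyond a single full sequence.
            num_blocks = max(num_blocks, max_seq_len // kvcache_block_size + 1)
        if num_blocks < batch_size:
            raise RuntimeError("Batch size exceeds available KV cache blocks.")
        return num_blocks

    assert resolve_kvcache_num_blocks(4096, 1024, 2, "eager") == 8
    assert resolve_kvcache_num_blocks(4096, 1024, 2, "flash_attn", max_num_blocks=6) == 6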
@@ -740,6 +717,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length,
             use_inputs_embeds,
             hidden_size,
+            use_attention_mask,
+            max_seq_len,
+            kvcache_block_size,
+            kvcache_num_blocks,
         ):
             if use_inputs_embeds:
                 main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
@@ -755,10 +736,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 ),
             ]

-            if rbln_use_attention_mask:
+            if use_attention_mask:
                 input_info.extend(
                     [
-                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
+                        ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
                     ]
                 )
@@ -769,7 +750,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 ]
             )

-            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+            max_block_cnt = max_seq_len // kvcache_block_size

             if query_length > 1:
                 input_info.extend([("block_tables", [max_block_cnt], "int16")])
@@ -781,9 +762,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 (
                     f"past_key_values_{i}",
                     [
-                        rbln_kvcache_num_blocks,
+                        kvcache_num_blocks,
                         num_key_value_heads,
-                        rbln_kvcache_block_size,
+                        kvcache_block_size,
                         head_dim,
                     ],
                     "float32",
@@ -796,42 +777,29 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         prefill_input_info = get_input_info(
             batch_size=1,
-            query_length=rbln_prefill_chunk_size,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            query_length=rbln_config.prefill_chunk_size,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
             hidden_size=hidden_size,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
         )
         dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
+            batch_size=rbln_config.batch_size,
             query_length=1,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
             hidden_size=hidden_size,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "prefill_chunk_size": rbln_prefill_chunk_size,
-                "use_attention_mask": rbln_use_attention_mask,
-                "use_inputs_embeds": rbln_use_inputs_embeds,
-                "kvcache_partition_len": rbln_kvcache_partition_len,
-                "kvcache_block_size": rbln_kvcache_block_size,
-                "attn_impl": rbln_attn_impl,
-                "kvcache_num_blocks": rbln_kvcache_num_blocks,
-            }
-        )
-
-        if rbln_quantization is not None:
-            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+        rbln_config.set_compile_cfgs([prefill_compile_config, dec_compile_config])

         return rbln_config
@@ -839,18 +807,23 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["prefill", "decoder"]):
             cls._raise_missing_compiled_file_error(["prefill", "decoder"])

         return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]
@@ -887,7 +860,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             model_inputs.update({"input_ids": input_ids})

         if inputs_embeds is not None:
-            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            if self.rbln_config.use_inputs_embeds:
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
optimum/rbln/transformers/models/dpt/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .configuration_dpt import RBLNDPTForDepthEstimationConfig
 from .modeling_dpt import RBLNDPTForDepthEstimation
optimum/rbln/transformers/models/dpt/configuration_dpt.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForDepthEstimationConfig
+
+
+class RBLNDPTForDepthEstimationConfig(RBLNModelForDepthEstimationConfig):
+    pass
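
These new `configuration_*.py` files are deliberately thin: each model gets a named, importable config class that inherits everything from a generic base. A hedged sketch of how the DPT pair might be used; the constructor parameters are assumed from the `image_size` and `batch_size` fields the deleted `_get_rbln_config` used to resolve, and `Intel/dpt-large` is just an example checkpoint:

    from optimum.rbln import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig

    # Assumed signature; RBLNModelForDepthEstimationConfig is not shown in this diff.
    rbln_config = RBLNDPTForDepthEstimationConfig(image_size=(384, 384), batch_size=1)
    model = RBLNDPTForDepthEstimation.from_pretrained(
        "Intel/dpt-large", export=True, rbln_config=rbln_config
    )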
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -12,82 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union

-from transformers import AutoModelForDepthEstimation
-from transformers.modeling_outputs import DepthEstimatorOutput
+from ...modeling_generic import RBLNModelForDepthEstimation

-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
-from ....utils.logging import get_logger

-
-logger = get_logger(__name__)
-
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
-
-
-class RBLNDPTForDepthEstimation(RBLNModel):
-    auto_model_class = AutoModelForDepthEstimation
-    main_input_name = "pixel_values"
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_image_size = rbln_kwargs.get("image_size", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        if rbln_image_size is None:
-            for processor in preprocessors:
-                image_size = getattr(processor, "size", None)
-
-                if image_size is not None:
-                    if isinstance(image_size, Iterable):
-                        if "shortest_edge" in image_size:
-                            rbln_image_size = image_size["shortest_edge"]
-                            break
-                        elif "height" in image_size and "width" in image_size:
-                            rbln_image_size = image_size["height"], image_size["width"]
-                            break
-                    else:
-                        rbln_image_size = image_size
-
-            if rbln_image_size is None:
-                rbln_image_size = getattr(model_config, "image_size", None)
-
-            if rbln_image_size is None:
-                raise ValueError("`rbln_image_size` should be specified!")
-
-        if isinstance(rbln_image_size, int):
-            rbln_image_size = rbln_image_size, rbln_image_size
-
-        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size[0], rbln_image_size[1]], "float32")]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "image_size": rbln_image_size,
-                "batch_size": rbln_batch_size,
-            }
-        )
-
-        return rbln_config
-
-    def forward(self, *args, **kwargs):
-        predicted_depth = super().forward(*args, **kwargs)
-        return DepthEstimatorOutput(predicted_depth=predicted_depth)
+class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
+    pass
optimum/rbln/transformers/models/exaone/__init__.py

@@ -20,4 +20,5 @@ this_path = os.path.abspath(__file__)
 local_dir = "/" + os.path.join(*this_path.split("/")[:-1]) + "/hf_hub_cached"
 environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir

+from .configuration_exaone import RBLNExaoneForCausalLMConfig
 from .modeling_exaone import RBLNExaoneForCausalLM
optimum/rbln/transformers/models/exaone/configuration_exaone.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNExaoneForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
optimum/rbln/transformers/models/gemma/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .configuration_gemma import RBLNGemmaForCausalLMConfig
 from .modeling_gemma import RBLNGemmaForCausalLM
optimum/rbln/transformers/models/gemma/configuration_gemma.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
optimum/rbln/transformers/models/gpt2/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig
 from .modeling_gpt2 import RBLNGPT2LMHeadModel
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
optimum/rbln/transformers/models/llama/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .configuration_llama import RBLNLlamaForCausalLMConfig
 from .modeling_llama import RBLNLlamaForCausalLM
optimum/rbln/transformers/models/llama/configuration_llama.py (new file)

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    pass
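
The exaone, gemma, gpt2, and llama additions all follow the same pattern: a `pass`-bodied subclass of `RBLNDecoderOnlyModelForCausalLMConfig`, so each architecture has a concrete, importable config type. The fields that base class must carry are exactly the attributes read throughout this diff; summarized below as an inferred sketch, not the real class body from `configuration_decoderonly.py`, with defaults taken from the 0.7.4a4 logic this release removes:

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class DecoderOnlyConfigFieldsSketch:
        batch_size: int = 1
        max_seq_len: Optional[int] = None   # else max_position_embeddings / n_positions
        prefill_chunk_size: int = 128       # a4 required a nonzero multiple of 64
        use_inputs_embeds: bool = False
        use_attention_mask: bool = False    # a4 forced this True on RBLN-CA02 NPUs
        attn_impl: Optional[str] = None     # "eager" or "flash_attn"
        kvcache_partition_len: Optional[int] = None
        kvcache_block_size: Optional[int] = None
        kvcache_num_blocks: Optional[int] = None  # derived in _update_rbln_config
        quantization: dict = field(default_factory=dict)  # e.g. {"format": "rbln"}
        tensor_parallel_size: Optional[int] = None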