optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -91,6 +91,10 @@ _import_structure = {
  "RBLNGemmaModel",
  "RBLNGemmaModelConfig",
  "RBLNGemmaForCausalLM",
+ "RBLNGemma2ForCausalLM",
+ "RBLNGemma2ForCausalLMConfig",
+ "RBLNGemma2Model",
+ "RBLNGemma2ModelConfig",
  "RBLNGemmaForCausalLMConfig",
  "RBLNGemma3ForCausalLM",
  "RBLNGemma3ForCausalLMConfig",
@@ -100,6 +104,8 @@ _import_structure = {
  "RBLNGPT2ModelConfig",
  "RBLNGPT2LMHeadModel",
  "RBLNGPT2LMHeadModelConfig",
+ "RBLNGptOssForCausalLM",
+ "RBLNGptOssForCausalLMConfig",
  "RBLNGroundingDinoDecoder",
  "RBLNGroundingDinoDecoderConfig",
  "RBLNGroundingDinoForObjectDetection",
@@ -140,14 +146,24 @@ _import_structure = {
  "RBLNPixtralVisionModelConfig",
  "RBLNPhiModel",
  "RBLNPhiModelConfig",
+ "RBLNPaliGemmaForConditionalGeneration",
+ "RBLNPaliGemmaForConditionalGenerationConfig",
+ "RBLNPaliGemmaModel",
+ "RBLNPaliGemmaModelConfig",
  "RBLNQwen2ForCausalLM",
  "RBLNQwen2ForCausalLMConfig",
  "RBLNQwen2_5_VisionTransformerPretrainedModel",
  "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
  "RBLNQwen2_5_VLForConditionalGeneration",
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
+ "RBLNQwen3MoeForCausalLM",
+ "RBLNQwen3MoeForCausalLMConfig",
+ "RBLNQwen2_5_VLModel",
+ "RBLNQwen2_5_VLModelConfig",
  "RBLNQwen2Model",
  "RBLNQwen2ModelConfig",
+ "RBLNQwen2MoeForCausalLM",
+ "RBLNQwen2MoeForCausalLMConfig",
  "RBLNQwen3ForCausalLM",
  "RBLNQwen3ForCausalLMConfig",
  "RBLNQwen3Model",
@@ -156,6 +172,8 @@ _import_structure = {
  "RBLNQwen2VisionTransformerPretrainedModelConfig",
  "RBLNQwen2VLForConditionalGeneration",
  "RBLNQwen2VLForConditionalGenerationConfig",
+ "RBLNQwen2VLModel",
+ "RBLNQwen2VLModelConfig",
  "RBLNResNetForImageClassification",
  "RBLNResNetForImageClassificationConfig",
  "RBLNRobertaForMaskedLM",
@@ -394,6 +412,10 @@ if TYPE_CHECKING:
  RBLNDPTForDepthEstimationConfig,
  RBLNExaoneForCausalLM,
  RBLNExaoneForCausalLMConfig,
+ RBLNGemma2ForCausalLM,
+ RBLNGemma2ForCausalLMConfig,
+ RBLNGemma2Model,
+ RBLNGemma2ModelConfig,
  RBLNGemma3ForCausalLM,
  RBLNGemma3ForCausalLMConfig,
  RBLNGemma3ForConditionalGeneration,
@@ -406,6 +428,8 @@ if TYPE_CHECKING:
  RBLNGPT2LMHeadModelConfig,
  RBLNGPT2Model,
  RBLNGPT2ModelConfig,
+ RBLNGptOssForCausalLM,
+ RBLNGptOssForCausalLMConfig,
  RBLNGroundingDinoDecoder,
  RBLNGroundingDinoDecoderConfig,
  RBLNGroundingDinoEncoder,
@@ -436,6 +460,10 @@ if TYPE_CHECKING:
  RBLNOPTForCausalLMConfig,
  RBLNOPTModel,
  RBLNOPTModelConfig,
+ RBLNPaliGemmaForConditionalGeneration,
+ RBLNPaliGemmaForConditionalGenerationConfig,
+ RBLNPaliGemmaModel,
+ RBLNPaliGemmaModelConfig,
  RBLNPegasusForConditionalGeneration,
  RBLNPegasusForConditionalGenerationConfig,
  RBLNPegasusModel,
@@ -450,18 +478,26 @@ if TYPE_CHECKING:
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
  RBLNQwen2_5_VLForConditionalGeneration,
  RBLNQwen2_5_VLForConditionalGenerationConfig,
+ RBLNQwen2_5_VLModel,
+ RBLNQwen2_5_VLModelConfig,
  RBLNQwen2ForCausalLM,
  RBLNQwen2ForCausalLMConfig,
  RBLNQwen2Model,
  RBLNQwen2ModelConfig,
+ RBLNQwen2MoeForCausalLM,
+ RBLNQwen2MoeForCausalLMConfig,
  RBLNQwen2VisionTransformerPretrainedModel,
  RBLNQwen2VisionTransformerPretrainedModelConfig,
  RBLNQwen2VLForConditionalGeneration,
  RBLNQwen2VLForConditionalGenerationConfig,
+ RBLNQwen2VLModel,
+ RBLNQwen2VLModelConfig,
  RBLNQwen3ForCausalLM,
  RBLNQwen3ForCausalLMConfig,
  RBLNQwen3Model,
  RBLNQwen3ModelConfig,
+ RBLNQwen3MoeForCausalLM,
+ RBLNQwen3MoeForCausalLMConfig,
  RBLNResNetForImageClassification,
  RBLNResNetForImageClassificationConfig,
  RBLNRobertaForMaskedLM,
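
The new entries register the Gemma-2, GPT-OSS, PaliGemma, Qwen2-MoE, and Qwen3-MoE classes, plus the plain Qwen2-VL / Qwen2.5-VL model classes and their configs, as top-level exports. A minimal, hedged sketch of loading one of them; the checkpoint id is a placeholder, and the export=True / rbln_config keywords follow the usual optimum-rbln from_pretrained pattern rather than anything introduced in this diff:

    from optimum.rbln import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig

    # Placeholder checkpoint; tensor_parallel_size is a base RBLNModelConfig field.
    model = RBLNGemma2ForCausalLM.from_pretrained(
        "google/gemma-2-2b-it",
        export=True,
        rbln_config=RBLNGemma2ForCausalLMConfig(tensor_parallel_size=4),
    )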
optimum/rbln/__version__.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.9.4a2'
- __version_tuple__ = version_tuple = (0, 9, 4, 'a2')
+ __version__ = version = '0.9.5a4'
+ __version_tuple__ = version_tuple = (0, 9, 5, 'a4')
 
  __commit_id__ = commit_id = None
optimum/rbln/configuration_utils.py CHANGED
@@ -24,7 +24,7 @@ import torch
  from packaging.version import Version
 
  from .__version__ import __version__
- from .utils.deprecation import warn_deprecated_npu
+ from .utils.deprecation import deprecate_kwarg, warn_deprecated_npu
  from .utils.logging import get_logger
  from .utils.runtime_utils import ContextRblnConfig
 
@@ -92,7 +92,7 @@ class RBLNCompileConfig:
  and isinstance(item[0], str) # name
  and isinstance(item[1], (tuple, list)) # shape
  and all(isinstance(x, int) for x in item[1])
- and isinstance(item[2], str) # dtype
+ and (isinstance(item[2], str) or isinstance(item[2], torch.dtype)) # dtype
  for item in input_info
  )
 
@@ -524,8 +524,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  non_save_attributes = [
  "_frozen",
  "_runtime_options",
- "torch_dtype",
  "npu",
+ "dtype",
  "tensor_parallel_size",
  "create_runtimes",
  "device",
@@ -650,6 +650,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
 
  super().__setattr__(key, value)
 
+ @deprecate_kwarg(
+ old_name="_torch_dtype",
+ new_name="dtype",
+ version="0.12.0",
+ deprecated_type=torch.dtype,
+ value_replacer=RBLNCompileConfig.normalize_dtype,
+ raise_if_greater_or_equal_version=False,
+ )
  def __init__(
  self,
  cls_name: Optional[str] = None,
@@ -661,7 +669,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  tensor_parallel_size: Optional[int] = None,
  timeout: Optional[int] = None,
  optimum_rbln_version: Optional[str] = None,
- _torch_dtype: Optional[str] = None,
+ dtype: Optional[Union[str, torch.dtype]] = None,
  _compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
  *,
  optimize_host_memory: Optional[bool] = None,
@@ -680,7 +688,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
  timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
  optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
- _torch_dtype (Optional[str]): The data type to use for the model.
+ dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
  _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
  kwargs: Additional keyword arguments.
 
@@ -710,7 +718,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  self.npu = npu
  self.tensor_parallel_size = tensor_parallel_size
 
- self._torch_dtype = _torch_dtype or "float32"
+ if dtype is not None and isinstance(dtype, torch.dtype):
+ dtype = RBLNCompileConfig.normalize_dtype(dtype)
+ self._dtype = dtype or "float32"
  self.optimum_rbln_version = optimum_rbln_version
  if self.optimum_rbln_version is None:
  self.optimum_rbln_version = __version__
@@ -743,14 +753,24 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
 
  @property
  def torch_dtype(self):
- return getattr(torch, self._torch_dtype)
+ logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
+ return self.dtype
 
  @torch_dtype.setter
  def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
- if isinstance(torch_dtype, torch.dtype):
- torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+ logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
+ self.dtype = torch_dtype
 
- self._torch_dtype = torch_dtype
+ @property
+ def dtype(self):
+ return getattr(torch, self._dtype)
+
+ @dtype.setter
+ def dtype(self, dtype: Union[str, torch.dtype]):
+ if isinstance(dtype, torch.dtype):
+ dtype = RBLNCompileConfig.normalize_dtype(dtype)
+
+ self._dtype = dtype
 
  @property
  def rbln_model_cls_name(self) -> str:
@@ -774,10 +794,15 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  if isinstance(value, RBLNSerializableConfigProtocol):
  # Convert nested RBLNModelConfig to its serializable form
  serializable_map[key] = value._prepare_for_serialization()
+ elif key == "_dtype":
+ serializable_map["dtype"] = value
+ elif isinstance(value, list) and all(isinstance(item, RBLNSerializableConfigProtocol) for item in value):
+ serializable_map[key] = [item._prepare_for_serialization() for item in value]
  elif key == "_compile_cfgs":
  serializable_map[key] = [cfg.asdict() for cfg in value]
  else:
  serializable_map[key] = value
+
  return serializable_map
 
  def __repr__(self):
@@ -825,18 +850,12 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
  if not isinstance(submodule_config, RBLNModelConfig):
  raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
 
- if not submodule_config.is_frozen():
- raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
-
  self._frozen = True
 
  def is_frozen(self):
  return self._frozen
 
  def save(self, path: str):
- if not self._frozen:
- raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
-
  # save as json file without runtime attributes
  path = Path(path)
  if path.is_dir():
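
The hunks above replace the private _torch_dtype constructor argument with a public dtype argument (accepting a string or a torch.dtype) and turn torch_dtype into a deprecated alias that warns and forwards to dtype. A small sketch of the resulting behaviour, assuming a concrete subclass such as RBLNQwen2ForCausalLMConfig forwards its keyword arguments to the base RBLNModelConfig:

    import torch
    from optimum.rbln import RBLNQwen2ForCausalLMConfig

    # torch.dtype values are normalized to their string names before being stored.
    cfg = RBLNQwen2ForCausalLMConfig(dtype=torch.bfloat16)
    print(cfg.dtype)        # torch.bfloat16
    print(cfg.torch_dtype)  # same value, but warns that torch_dtype is deprecated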
optimum/rbln/modeling_base.py CHANGED
@@ -90,7 +90,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
  self.device = torch.device("cpu")
  self.training = False
- self.dtype = rbln_config.torch_dtype
+ self.dtype = rbln_config.dtype
 
  # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
  # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -223,8 +223,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  elif rbln_submodules is None:
  rbln_submodules = []
 
- rbln_config.freeze()
-
  if config is None:
  if cls.hf_library_name == "transformers":
  config = AutoConfig.from_pretrained(
@@ -313,6 +311,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  )
  raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
 
+ rbln_config.freeze()
+
  return cls(
  models,
  config,
@@ -451,15 +451,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  model_config: "PretrainedConfig",
  rbln_config: RBLNModelConfig,
  ) -> RBLNModelConfig:
- rbln_config.torch_dtype = model.dtype
- if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+ rbln_config.dtype = model.dtype
+ if not cls._supports_non_fp32 and rbln_config.dtype != torch.float32:
  raise NotImplementedError(
  f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
  )
  rbln_config = cls._update_rbln_config(
  preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
  )
- rbln_config.freeze()
+
  if rbln_config.rbln_model_cls_name != cls.__name__:
  raise NameError(
  f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
optimum/rbln/ops/__init__.py CHANGED
@@ -16,4 +16,5 @@ from .attn import *
  from .flash_attn import *
  from .kv_cache_update import *
  from .linear import linear
+ from .moe import *
  from .sliding_window_attn import *
optimum/rbln/ops/attn.py CHANGED
@@ -205,6 +205,7 @@ def paged_causal_attn_decode(
  block_table: Tensor,
  block_size: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for fused attention with KV cache updates.
 
@@ -228,6 +229,7 @@ def paged_causal_attn_decode(
  - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
  - block_size: [] - Number of tokens per block
  - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+ - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
  Returns:
  Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
@@ -247,6 +249,7 @@ def paged_causal_attn_decode_fake(
  block_table: Tensor,
  block_size: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -267,6 +270,7 @@ def paged_causal_attn_prefill(
  block_size: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -290,6 +294,7 @@ def paged_causal_attn_prefill(
  - block_size: [] - Number of tokens per block
  - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
  - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+ - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
  Returns:
  Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
@@ -310,6 +315,7 @@ def paged_causal_attn_prefill_fake(
  block_size: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -331,6 +337,7 @@ def paged_causal_attn_decode_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -349,6 +356,7 @@ def paged_causal_attn_decode_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -371,6 +379,7 @@ def paged_causal_attn_prefill_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -390,6 +399,7 @@ def paged_causal_attn_prefill_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
optimum/rbln/ops/flash_attn.py CHANGED
@@ -198,6 +198,7 @@ def paged_flash_causal_attn_decode(
  block_size: int,
  partition: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
 
@@ -219,6 +220,7 @@ def paged_flash_causal_attn_decode_fake(
  block_size: int,
  partition: int,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -241,6 +243,7 @@ def paged_flash_causal_attn_decode_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -260,6 +263,7 @@ def paged_flash_causal_attn_decode_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -281,6 +285,7 @@ def paged_flash_causal_attn_prefill(
  partition: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
 
@@ -303,6 +308,7 @@ def paged_flash_causal_attn_prefill_fake(
  partition: int,
  is_bidirectional: bool,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -326,6 +332,7 @@ def paged_flash_causal_attn_prefill_kv_fp8(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -346,5 +353,6 @@ def paged_flash_causal_attn_prefill_kv_fp8_fake(
  k_scale: Tensor,
  v_scale: Tensor,
  mask: Optional[Tensor] = None,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
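
Every paged attention op in attn.py and flash_attn.py (and, further below, sliding_window_attn.py) gains a trailing optional s_aux argument documented as [num_attention_heads, sink_len] "auxiliary states for attention", most likely the attention-sink values used by models such as GPT-OSS. A hedged sketch of the tensor these signatures now accept; the values are placeholders, and how a model fills them in is model-specific and not defined by this diff:

    import torch

    num_attention_heads, sink_len = 32, 1
    s_aux = torch.zeros(num_attention_heads, sink_len)

    # Passed as the new trailing keyword, for example:
    #   paged_causal_attn_decode(..., mask=mask, s_aux=s_aux)
    # Omitting it (or passing None) keeps the previous behaviour.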
optimum/rbln/ops/moe.py ADDED
@@ -0,0 +1,180 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::custom_moe_glu",
+ mutates_args=(),
+ )
+ def custom_moe_glu(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ up_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ router_logits: Tensor,
+ topk: int,
+ norm_topk_prob: bool,
+ gate_proj_bias: Optional[Tensor] = None,
+ up_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Customized MoE GLU operation.
+
+ Expected tensor shapes:
+ - hidden_states: [batch*seq_len, hidden_size]
+ - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
+ - up_proj_weight: [num_experts, hidden_size, intermediate_size]
+ - down_proj_weight: [num_experts, intermediate_size, hidden_size]
+ - router_logits: [batch*seq_len, num_experts]
+ - topk: top k experts to select
+ - norm_topk_prob: whether to normalize the top k routing weights with softmax
+ - gate_proj_bias: [num_experts, intermediate_size]
+ - up_proj_bias: [num_experts, intermediate_size]
+ - down_proj_bias: [num_experts, hidden_size]
+
+ Returns:
+ Tensor: [batch * seq_len, hidden_size]
+ """
+
+ return torch.empty_like(hidden_states)
+
+
+ @custom_moe_glu.register_fake
+ def custom_moe_glu_fake(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ up_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ router_logits: Tensor,
+ topk: int,
+ norm_topk_prob: bool,
+ gate_proj_bias: Optional[Tensor] = None,
+ up_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ return torch.empty_like(hidden_states)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::custom_moe_ff",
+ mutates_args=(),
+ )
+ def custom_moe_ff(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ masked_routing_weight: Tensor,
+ gate_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ """
+ Customized MoE FF operation.
+
+ Expected tensor shapes:
+ - hidden_states: [batch * seq_len, hidden_size]
+ - gate_proj_weight: [hidden_size, num_experts * intermediate_size]
+ - down_proj_weight: [num_experts * intermediate_size, hidden_size]
+ - masked_routing_weight: [batch * seq_len, num_experts]
+ - gate_proj_bias: [num_experts * intermediate_size]
+ - down_proj_bias: [hidden_size]
+
+ Returns:
+ Tensor: [batch * seq_len, hidden_size]
+ """
+ return torch.empty_like(hidden_states)
+
+
+ @custom_moe_ff.register_fake
+ def custom_moe_ff_fake(
+ hidden_states: Tensor,
+ gate_proj_weight: Tensor,
+ down_proj_weight: Tensor,
+ masked_routing_weight: Tensor,
+ gate_proj_bias: Optional[Tensor] = None,
+ down_proj_bias: Optional[Tensor] = None,
+ ) -> Tensor:
+ return torch.empty_like(hidden_states)
+
+
+ @torch.library.custom_op(
+ "rbln_custom_ops::custom_moe_glu_mxfp4",
+ mutates_args=(),
+ )
+ def custom_moe_glu_mxfp4(
+ hidden_states: Tensor,
+ gate_proj_blocks: Tensor,
+ gate_proj_scales: Tensor,
+ gate_proj_bias: Tensor,
+ up_proj_blocks: Tensor,
+ up_proj_scales: Tensor,
+ up_proj_bias: Tensor,
+ down_proj_blocks: Tensor,
+ down_proj_scales: Tensor,
+ down_proj_bias: Tensor,
+ router_logits: Tensor,
+ alpha: Tensor,
+ limit: Tensor,
+ k: int,
+ post_norm: bool,
+ ) -> Tensor:
+ """
+ Customized MoE GLU operation.
+
+ Expected tensor shapes:
+ - hidden_states: [batch*seq_len, hidden_size]
+ - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+ - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+ - gate_proj_bias: [num_experts, intermediate_size]
+ - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+ - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+ - up_proj_bias: [num_experts, intermediate_size]
+ - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2]
+ - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32]
+ - masked_routing_weight: [batch * seq_len, num_experts]
+ - expert_select_count: [num_experts]
+ - alpha: []
+ - limit: []
+
+ Returns:
+ Tensor: [batch * seq_len, hidden_size]
+ """
+
+ return torch.empty_like(hidden_states)
+
+
+ @custom_moe_glu_mxfp4.register_fake
+ def custom_moe_glu_mxfp4_fake(
+ hidden_states: Tensor,
+ gate_proj_blocks: Tensor,
+ gate_proj_scales: Tensor,
+ gate_proj_bias: Tensor,
+ up_proj_blocks: Tensor,
+ up_proj_scales: Tensor,
+ up_proj_bias: Tensor,
+ down_proj_blocks: Tensor,
+ down_proj_scales: Tensor,
+ down_proj_bias: Tensor,
+ router_logits: Tensor,
+ alpha: Tensor,
+ limit: Tensor,
+ k: int,
+ post_norm: bool,
+ ) -> Tensor:
+ return torch.empty_like(hidden_states)
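
The new module registers three custom ops (custom_moe_glu, custom_moe_ff, custom_moe_glu_mxfp4) under the rbln_custom_ops namespace. Their eager implementations are shape stubs that return an uninitialized tensor shaped like hidden_states, so outside the RBLN compiler they are mainly useful for tracing and shape checks. A minimal sketch that calls custom_moe_glu with tensors of the shapes documented above, assuming that importing optimum.rbln.ops registers the op as shown in the ops/__init__.py change:

    import torch
    import optimum.rbln.ops  # noqa: F401  (registers the rbln_custom_ops library)

    tokens, hidden, inter, experts, topk = 8, 64, 128, 4, 2
    hidden_states = torch.randn(tokens, hidden)
    gate_w = torch.randn(experts, hidden, inter)
    up_w = torch.randn(experts, hidden, inter)
    down_w = torch.randn(experts, inter, hidden)
    router_logits = torch.randn(tokens, experts)

    out = torch.ops.rbln_custom_ops.custom_moe_glu(
        hidden_states, gate_w, up_w, down_w, router_logits, topk, True
    )
    print(out.shape)  # torch.Size([8, 64]), same as hidden_states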
optimum/rbln/ops/sliding_window_attn.py CHANGED
@@ -13,6 +13,8 @@
  # limitations under the License.
 
 
+ from typing import Optional
+
  import torch
  from torch import Tensor
 
@@ -33,6 +35,7 @@ def paged_sliding_window_attn_prefill(
  block_table: Tensor,
  block_size: int,
  is_bidirectional: bool,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -53,6 +56,7 @@ def paged_sliding_window_attn_prefill(
  - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
  - scale: [] - Attention scale factor
  - is_bidirectional: [] - Whether the attention is bidirectional
+ - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
  Returns:
  Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
  """
@@ -72,6 +76,7 @@ def paged_sliding_window_attn_prefill_fake(
  block_table: Tensor,
  block_size: int,
  is_bidirectional: bool,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -91,6 +96,8 @@ def paged_sliding_window_attn_decode(
  scale: Tensor,
  block_table: Tensor,
  block_size: int,
+ attn_mask: Tensor,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)
 
@@ -107,5 +114,7 @@ def paged_sliding_window_attn_decode_fake(
  scale: Tensor,
  block_table: Tensor,
  block_size: int,
+ attn_mask: Tensor,
+ s_aux: Optional[Tensor] = None,
  ) -> Tensor:
  return torch.empty_like(q)