optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
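
At a glance, the headline changes in this release: new model families (Gemma2, GPT-OSS, PaliGemma, Qwen2-MoE, Qwen3-MoE), new fused MoE custom ops, an `s_aux` (attention-sink) argument threaded through every paged-attention op, and a `torch_dtype` → `dtype` rename in `RBLNModelConfig`. As a hedged sketch of what the new exports enable — the checkpoint id and kwarg values below are illustrative, not taken from this diff:

```python
# Minimal sketch, assuming an installed RBLN SDK and NPU. `export=True` follows
# optimum-rbln's usual compile-from-HF-weights convention; the checkpoint id and
# tensor-parallel size are hypothetical choices.
from optimum.rbln import RBLNGptOssForCausalLM

model = RBLNGptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",         # illustrative checkpoint
    export=True,                  # compile from Hugging Face weights
    rbln_tensor_parallel_size=4,  # illustrative parallelism setting
)
model.save_pretrained("gpt-oss-20b-rbln")
```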
optimum/rbln/__init__.py
CHANGED
```diff
@@ -91,6 +91,10 @@ _import_structure = {
         "RBLNGemmaModel",
         "RBLNGemmaModelConfig",
         "RBLNGemmaForCausalLM",
+        "RBLNGemma2ForCausalLM",
+        "RBLNGemma2ForCausalLMConfig",
+        "RBLNGemma2Model",
+        "RBLNGemma2ModelConfig",
         "RBLNGemmaForCausalLMConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
@@ -100,6 +104,8 @@ _import_structure = {
         "RBLNGPT2ModelConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
+        "RBLNGptOssForCausalLM",
+        "RBLNGptOssForCausalLMConfig",
         "RBLNGroundingDinoDecoder",
         "RBLNGroundingDinoDecoderConfig",
         "RBLNGroundingDinoForObjectDetection",
@@ -140,14 +146,24 @@ _import_structure = {
         "RBLNPixtralVisionModelConfig",
         "RBLNPhiModel",
         "RBLNPhiModelConfig",
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen3MoeForCausalLM",
+        "RBLNQwen3MoeForCausalLMConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
         "RBLNQwen2Model",
         "RBLNQwen2ModelConfig",
+        "RBLNQwen2MoeForCausalLM",
+        "RBLNQwen2MoeForCausalLMConfig",
         "RBLNQwen3ForCausalLM",
         "RBLNQwen3ForCausalLMConfig",
         "RBLNQwen3Model",
@@ -156,6 +172,8 @@ _import_structure = {
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
         "RBLNResNetForImageClassification",
         "RBLNResNetForImageClassificationConfig",
         "RBLNRobertaForMaskedLM",
@@ -394,6 +412,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma2ForCausalLM,
+        RBLNGemma2ForCausalLMConfig,
+        RBLNGemma2Model,
+        RBLNGemma2ModelConfig,
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
         RBLNGemma3ForConditionalGeneration,
@@ -406,6 +428,8 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGptOssForCausalLM,
+        RBLNGptOssForCausalLMConfig,
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
         RBLNGroundingDinoEncoder,
@@ -436,6 +460,10 @@ if TYPE_CHECKING:
         RBLNOPTForCausalLMConfig,
         RBLNOPTModel,
         RBLNOPTModelConfig,
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
         RBLNPegasusModel,
@@ -450,18 +478,26 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
         RBLNQwen2Model,
         RBLNQwen2ModelConfig,
+        RBLNQwen2MoeForCausalLM,
+        RBLNQwen2MoeForCausalLMConfig,
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
         RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
         RBLNQwen3ForCausalLM,
         RBLNQwen3ForCausalLMConfig,
         RBLNQwen3Model,
         RBLNQwen3ModelConfig,
+        RBLNQwen3MoeForCausalLM,
+        RBLNQwen3MoeForCausalLMConfig,
         RBLNResNetForImageClassification,
         RBLNResNetForImageClassificationConfig,
         RBLNRobertaForMaskedLM,
```
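
Every new class above is registered twice on purpose: once in `_import_structure` (resolved lazily at runtime) and once under `if TYPE_CHECKING:` (visible to static type checkers without triggering imports). A stripped-down sketch of the pattern, using PEP 562 module `__getattr__` in place of the `_LazyModule` machinery the real file presumably relies on:

```python
# Sketch of the lazy-export pattern (PEP 562); the module layout is hypothetical.
import importlib
from typing import TYPE_CHECKING

_import_structure = {"transformers": ["RBLNGemma2ForCausalLM", "RBLNGptOssForCausalLM"]}

if TYPE_CHECKING:
    from .transformers import RBLNGemma2ForCausalLM, RBLNGptOssForCausalLM  # noqa: F401
else:
    def __getattr__(name):
        # Import the owning submodule only when the name is first touched.
        for submodule, names in _import_structure.items():
            if name in names:
                return getattr(importlib.import_module(f".{submodule}", __package__), name)
        raise AttributeError(name)
```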
optimum/rbln/__version__.py
CHANGED
```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.9.4a2'
-__version_tuple__ = version_tuple = (0, 9, 4, 'a2')
+__version__ = version = '0.9.5a4'
+__version_tuple__ = version_tuple = (0, 9, 5, 'a4')
 
 __commit_id__ = commit_id = None
```
optimum/rbln/configuration_utils.py
CHANGED
```diff
@@ -24,7 +24,7 @@ import torch
 from packaging.version import Version
 
 from .__version__ import __version__
-from .utils.deprecation import warn_deprecated_npu
+from .utils.deprecation import deprecate_kwarg, warn_deprecated_npu
 from .utils.logging import get_logger
 from .utils.runtime_utils import ContextRblnConfig
 
@@ -92,7 +92,7 @@ class RBLNCompileConfig:
             and isinstance(item[0], str)  # name
             and isinstance(item[1], (tuple, list))  # shape
             and all(isinstance(x, int) for x in item[1])
-            and isinstance(item[2], str)  # dtype
+            and (isinstance(item[2], str) or isinstance(item[2], torch.dtype))  # dtype
             for item in input_info
         )
 
@@ -524,8 +524,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
     non_save_attributes = [
         "_frozen",
         "_runtime_options",
-        "torch_dtype",
         "npu",
+        "dtype",
         "tensor_parallel_size",
         "create_runtimes",
         "device",
@@ -650,6 +650,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
 
         super().__setattr__(key, value)
 
+    @deprecate_kwarg(
+        old_name="_torch_dtype",
+        new_name="dtype",
+        version="0.12.0",
+        deprecated_type=torch.dtype,
+        value_replacer=RBLNCompileConfig.normalize_dtype,
+        raise_if_greater_or_equal_version=False,
+    )
     def __init__(
         self,
         cls_name: Optional[str] = None,
@@ -661,7 +669,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         tensor_parallel_size: Optional[int] = None,
         timeout: Optional[int] = None,
         optimum_rbln_version: Optional[str] = None,
-        _torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        dtype: Optional[Union[str, torch.dtype]] = None,
         _compile_cfgs: Optional[List[RBLNCompileConfig]] = None,
         *,
         optimize_host_memory: Optional[bool] = None,
@@ -680,7 +688,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
             timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
             optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
-            _torch_dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
+            dtype (Optional[Union[str, torch.dtype]]): The data type to use for the model.
             _compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
             kwargs: Additional keyword arguments.
 
@@ -710,7 +718,9 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
         self.npu = npu
         self.tensor_parallel_size = tensor_parallel_size
 
-        self._torch_dtype = _torch_dtype or "float32"
+        if dtype is not None and isinstance(dtype, torch.dtype):
+            dtype = RBLNCompileConfig.normalize_dtype(dtype)
+        self._dtype = dtype or "float32"
         self.optimum_rbln_version = optimum_rbln_version
         if self.optimum_rbln_version is None:
             self.optimum_rbln_version = __version__
@@ -743,14 +753,24 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
 
     @property
     def torch_dtype(self):
-        return getattr(torch, self._torch_dtype)
+        logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
+        return self.dtype
 
     @torch_dtype.setter
     def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
-        if isinstance(torch_dtype, torch.dtype):
-            torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
+        logger.warning_once("`torch_dtype` is deprecated. Use `dtype` instead.")
+        self.dtype = torch_dtype
 
-        self._torch_dtype = torch_dtype
+    @property
+    def dtype(self):
+        return getattr(torch, self._dtype)
+
+    @dtype.setter
+    def dtype(self, dtype: Union[str, torch.dtype]):
+        if isinstance(dtype, torch.dtype):
+            dtype = RBLNCompileConfig.normalize_dtype(dtype)
+
+        self._dtype = dtype
 
     @property
     def rbln_model_cls_name(self) -> str:
@@ -774,10 +794,15 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             if isinstance(value, RBLNSerializableConfigProtocol):
                 # Convert nested RBLNModelConfig to its serializable form
                 serializable_map[key] = value._prepare_for_serialization()
+            elif key == "_dtype":
+                serializable_map["dtype"] = value
+            elif isinstance(value, list) and all(isinstance(item, RBLNSerializableConfigProtocol) for item in value):
+                serializable_map[key] = [item._prepare_for_serialization() for item in value]
             elif key == "_compile_cfgs":
                 serializable_map[key] = [cfg.asdict() for cfg in value]
             else:
                 serializable_map[key] = value
+
         return serializable_map
 
     def __repr__(self):
@@ -825,18 +850,12 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
             if not isinstance(submodule_config, RBLNModelConfig):
                 raise ValueError(f"`{submodule_name}` must be an instance of `RBLNModelConfig` before freezing.")
 
-            if not submodule_config.is_frozen():
-                raise ValueError(f"`{submodule_name}` config must be frozen before freezing super config.")
-
         self._frozen = True
 
     def is_frozen(self):
         return self._frozen
 
     def save(self, path: str):
-        if not self._frozen:
-            raise RuntimeError("`RBLNModelConfig` is not frozen. Please call `set_compile_cfgs` first.")
-
         # save as json file without runtime attributes
         path = Path(path)
         if path.is_dir():
```
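
Net effect of the `configuration_utils.py` changes: the config's data type now lives under `dtype` (backed by the private `_dtype` string and serialized as `"dtype"`), both `str` and `torch.dtype` inputs are accepted (the latter normalized via `RBLNCompileConfig.normalize_dtype`), and `torch_dtype` survives as a warning-emitting alias slated for removal in 0.12.0. A small behavioral sketch, assuming a concrete subclass such as `RBLNQwen2ForCausalLMConfig` forwards these kwargs unchanged:

```python
import torch
from optimum.rbln import RBLNQwen2ForCausalLMConfig

cfg = RBLNQwen2ForCausalLMConfig(dtype=torch.bfloat16)  # torch.dtype normalized to "bfloat16"
assert cfg.dtype == torch.bfloat16  # getter rebuilds it via getattr(torch, self._dtype)

cfg.dtype = "float16"               # strings are stored as-is
assert cfg.dtype == torch.float16

_ = cfg.torch_dtype  # still works, but warns: "`torch_dtype` is deprecated. Use `dtype` instead."
```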
optimum/rbln/modeling_base.py
CHANGED
```diff
@@ -90,7 +90,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
         self.device = torch.device("cpu")
         self.training = False
-        self.dtype = rbln_config.torch_dtype
+        self.dtype = rbln_config.dtype
 
         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -223,8 +223,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         elif rbln_submodules is None:
             rbln_submodules = []
 
-        rbln_config.freeze()
-
         if config is None:
             if cls.hf_library_name == "transformers":
                 config = AutoConfig.from_pretrained(
@@ -313,6 +311,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             )
             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
 
+        rbln_config.freeze()
+
         return cls(
             models,
             config,
@@ -451,15 +451,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         model_config: "PretrainedConfig",
         rbln_config: RBLNModelConfig,
     ) -> RBLNModelConfig:
-        rbln_config.torch_dtype = model.dtype
-        if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+        rbln_config.dtype = model.dtype
+        if not cls._supports_non_fp32 and rbln_config.dtype != torch.float32:
             raise NotImplementedError(
                 f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
             )
         rbln_config = cls._update_rbln_config(
             preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
         )
-
+
         if rbln_config.rbln_model_cls_name != cls.__name__:
             raise NameError(
                 f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
```
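
Two behavioral shifts in `modeling_base.py` are easy to miss in the noise: `rbln_config.freeze()` now runs after runtime creation instead of before compilation, so `_update_rbln_config` and the new `rbln_config.dtype = model.dtype` assignment can still mutate the config during export; and, per the `configuration_utils.py` hunks above, `save()` no longer requires a frozen config and `freeze()` no longer demands pre-frozen submodule configs. A hedged sketch of the relaxed contract (output path hypothetical):

```python
# Sketch under the assumption that a concrete config subclass is constructible
# with defaults; "./rbln_out" is a hypothetical path.
from optimum.rbln import RBLNQwen2ForCausalLMConfig

cfg = RBLNQwen2ForCausalLMConfig()
assert not cfg.is_frozen()
cfg.dtype = "bfloat16"   # pre-freeze mutation is fine
cfg.save("./rbln_out")   # raised RuntimeError in 0.9.4 when unfrozen; now it saves
cfg.freeze()             # flips _frozen; __setattr__ presumably rejects writes afterwards
```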
optimum/rbln/ops/__init__.py
CHANGED
optimum/rbln/ops/attn.py
CHANGED
```diff
@@ -205,6 +205,7 @@ def paged_causal_attn_decode(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused attention with KV cache updates.
 
@@ -228,6 +229,7 @@ def paged_causal_attn_decode(
     - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
     - block_size: [] - Number of tokens per block
     - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
@@ -247,6 +249,7 @@ def paged_causal_attn_decode_fake(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -267,6 +270,7 @@ def paged_causal_attn_prefill(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -290,6 +294,7 @@ def paged_causal_attn_prefill(
     - block_size: [] - Number of tokens per block
     - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
     - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
@@ -310,6 +315,7 @@ def paged_causal_attn_prefill_fake(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -331,6 +337,7 @@ def paged_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -349,6 +356,7 @@ def paged_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -371,6 +379,7 @@ def paged_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -390,6 +399,7 @@ def paged_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
```
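
The recurring `s_aux` parameter added across these ops is documented only as `[num_attention_heads, sink_len] - auxiliary states for attention`; given the GPT-OSS support added in the same release, this reads as attention-sink logits: per-head values that join the softmax normalization but contribute no value vectors. The RBLN kernels are opaque here, so the following is a CPU reference sketch of that interpretation, not the device implementation:

```python
import torch

def sdpa_with_sink(q, k, v, s_aux, scale):
    # q: [heads, q_len, d], k/v: [heads, kv_len, d], s_aux: [heads, sink_len]
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale
    sink = s_aux[:, None, :].expand(-1, q.shape[1], -1)  # broadcast sinks over queries
    probs = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)
    # Sink columns absorb probability mass but are dropped before the value matmul.
    return torch.einsum("hqk,hkd->hqd", probs[..., : k.shape[1]], v)

out = sdpa_with_sink(
    torch.randn(8, 4, 64), torch.randn(8, 16, 64), torch.randn(8, 16, 64),
    torch.zeros(8, 1), scale=64 ** -0.5,
)
print(out.shape)  # torch.Size([8, 4, 64])
```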
optimum/rbln/ops/flash_attn.py
CHANGED
```diff
@@ -198,6 +198,7 @@ def paged_flash_causal_attn_decode(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
 
@@ -219,6 +220,7 @@ def paged_flash_causal_attn_decode_fake(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -241,6 +243,7 @@ def paged_flash_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -260,6 +263,7 @@ def paged_flash_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -281,6 +285,7 @@ def paged_flash_causal_attn_prefill(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
 
@@ -303,6 +308,7 @@ def paged_flash_causal_attn_prefill_fake(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -326,6 +332,7 @@ def paged_flash_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -346,5 +353,6 @@ def paged_flash_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
```
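
All of these flash/paged variants share the `(block_table, block_size)` addressing scheme from their docstrings: logical token position `p` of a sequence lives in physical cache block `block_table[p // block_size]`, at offset `p % block_size` within that block. A CPU sketch of that indexing for a single sequence (the real gather happens inside the opaque kernel):

```python
import torch

def gather_kv(cache, block_table, seq_len, block_size):
    # cache: [num_blocks, block_size, n_heads, head_dim] -> [seq_len, n_heads, head_dim]
    pos = torch.arange(seq_len)
    blocks = block_table[pos // block_size]   # physical block per logical position
    return cache[blocks, pos % block_size]

cache = torch.randn(10, 4, 2, 8)              # 10 physical blocks of 4 tokens each
block_table = torch.tensor([7, 2, 5])         # this sequence owns blocks 7, 2, 5
print(gather_kv(cache, block_table, 9, 4).shape)  # torch.Size([9, 2, 8])
```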
optimum/rbln/ops/moe.py
ADDED
```diff
@@ -0,0 +1,180 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu",
+    mutates_args=(),
+)
+def custom_moe_glu(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+    - hidden_states: [batch*seq_len, hidden_size]
+    - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
+    - up_proj_weight: [num_experts, hidden_size, intermediate_size]
+    - down_proj_weight: [num_experts, intermediate_size, hidden_size]
+    - router_logits: [batch*seq_len, num_experts]
+    - topk: top k experts to select
+    - norm_topk_prob: whether to normalize the top k routing weights with softmax
+    - gate_proj_bias: [num_experts, intermediate_size]
+    - up_proj_bias: [num_experts, intermediate_size]
+    - down_proj_bias: [num_experts, hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu.register_fake
+def custom_moe_glu_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_ff",
+    mutates_args=(),
+)
+def custom_moe_ff(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE FF operation.
+
+    Expected tensor shapes:
+    - hidden_states: [batch * seq_len, hidden_size]
+    - gate_proj_weight: [hidden_size, num_experts * intermediate_size]
+    - down_proj_weight: [num_experts * intermediate_size, hidden_size]
+    - masked_routing_weight: [batch * seq_len, num_experts]
+    - gate_proj_bias: [num_experts * intermediate_size]
+    - down_proj_bias: [hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_ff.register_fake
+def custom_moe_ff_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu_mxfp4",
+    mutates_args=(),
+)
+def custom_moe_glu_mxfp4(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+    - hidden_states: [batch*seq_len, hidden_size]
+    - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+    - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+    - gate_proj_bias: [num_experts, intermediate_size]
+    - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+    - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+    - up_proj_bias: [num_experts, intermediate_size]
+    - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2]
+    - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32]
+    - masked_routing_weight: [batch * seq_len, num_experts]
+    - expert_select_count: [num_experts]
+    - alpha: []
+    - limit: []
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu_mxfp4.register_fake
+def custom_moe_glu_mxfp4_fake(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
```
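
Like the attention ops, these MoE ops are host-side stand-ins: the `torch.library.custom_op` bodies return `torch.empty_like(hidden_states)`, and the RBLN compiler swaps in the fused kernels, so on CPU only the output shape is meaningful. A sketch of invoking the GLU variant through the op registry once the module has been imported (all dimensions arbitrary):

```python
import torch
import optimum.rbln.ops.moe  # noqa: F401 - registers rbln_custom_ops::custom_moe_glu

tokens, hidden, inter, experts = 2, 8, 16, 4
out = torch.ops.rbln_custom_ops.custom_moe_glu(
    torch.randn(tokens, hidden),          # hidden_states
    torch.randn(experts, hidden, inter),  # gate_proj_weight
    torch.randn(experts, hidden, inter),  # up_proj_weight
    torch.randn(experts, inter, hidden),  # down_proj_weight
    torch.randn(tokens, experts),         # router_logits
    2,                                    # topk
    True,                                 # norm_topk_prob
)
print(out.shape)  # torch.Size([2, 8]) - placeholder values on CPU, shape only
```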
optimum/rbln/ops/sliding_window_attn.py
CHANGED
```diff
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 
@@ -33,6 +35,7 @@ def paged_sliding_window_attn_prefill(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -53,6 +56,7 @@ def paged_sliding_window_attn_prefill(
     - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
     - scale: [] - Attention scale factor
     - is_bidirectional: [] - Whether the attention is bidirectional
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
     """
@@ -72,6 +76,7 @@ def paged_sliding_window_attn_prefill_fake(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -91,6 +96,8 @@ def paged_sliding_window_attn_decode(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -107,5 +114,7 @@ def paged_sliding_window_attn_decode_fake(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
```