optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
CHANGED
|
@@ -23,9 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
|
|
23
23
|
import rebel
|
|
24
24
|
import torch
|
|
25
25
|
from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
|
|
26
|
+
from transformers.utils.hub import PushToHubMixin
|
|
26
27
|
|
|
27
28
|
from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
|
|
28
|
-
from .utils.hub import
|
|
29
|
+
from .utils.hub import pull_compiled_model_from_hub, validate_files
|
|
29
30
|
from .utils.logging import get_logger
|
|
30
31
|
from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
|
|
31
32
|
from .utils.save_utils import maybe_load_preprocessors
|
|
@@ -33,7 +34,7 @@ from .utils.submodule import SubModulesMixin
|
|
|
33
34
|
|
|
34
35
|
|
|
35
36
|
if TYPE_CHECKING:
|
|
36
|
-
from transformers import PreTrainedModel
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
37
38
|
|
|
38
39
|
logger = get_logger(__name__)
|
|
39
40
|
|
|
@@ -50,11 +51,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
50
51
|
model_type = "rbln_model"
|
|
51
52
|
auto_model_class = AutoModel
|
|
52
53
|
config_class = AutoConfig
|
|
53
|
-
|
|
54
54
|
config_name = "config.json"
|
|
55
55
|
hf_library_name = "transformers"
|
|
56
|
-
|
|
57
|
-
_rbln_config_class = None
|
|
56
|
+
_supports_non_fp32 = False
|
|
58
57
|
|
|
59
58
|
def __init__(
|
|
60
59
|
self,
|
|
@@ -72,7 +71,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
72
71
|
self.rbln_config = rbln_config
|
|
73
72
|
if not rbln_config.is_frozen():
|
|
74
73
|
raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
|
|
75
|
-
|
|
76
74
|
self.compiled_models = rbln_compiled_models
|
|
77
75
|
|
|
78
76
|
# Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
|
|
@@ -93,7 +91,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
93
91
|
|
|
94
92
|
self.device = torch.device("cpu")
|
|
95
93
|
self.training = False
|
|
96
|
-
self.dtype =
|
|
94
|
+
self.dtype = rbln_config.torch_dtype
|
|
97
95
|
|
|
98
96
|
# FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
|
|
99
97
|
# This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
|
|
@@ -115,7 +113,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
115
113
|
def _load_compiled_model_dir(
|
|
116
114
|
cls,
|
|
117
115
|
model_id: Union[str, Path],
|
|
118
|
-
|
|
116
|
+
token: Optional[Union[bool, str]] = None,
|
|
119
117
|
revision: Optional[str] = None,
|
|
120
118
|
force_download: bool = False,
|
|
121
119
|
cache_dir: Optional[str] = None,
|
|
@@ -134,7 +132,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
134
132
|
model_path = pull_compiled_model_from_hub(
|
|
135
133
|
model_id=model_id,
|
|
136
134
|
subfolder=subfolder,
|
|
137
|
-
|
|
135
|
+
token=token,
|
|
138
136
|
revision=revision,
|
|
139
137
|
cache_dir=cache_dir,
|
|
140
138
|
force_download=force_download,
|
|
@@ -172,7 +170,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
172
170
|
cls,
|
|
173
171
|
model_id: Union[str, Path],
|
|
174
172
|
config: Optional["PretrainedConfig"] = None,
|
|
175
|
-
|
|
173
|
+
token: Optional[Union[bool, str]] = None,
|
|
176
174
|
revision: Optional[str] = None,
|
|
177
175
|
force_download: bool = False,
|
|
178
176
|
cache_dir: Optional[str] = None,
|
|
@@ -189,7 +187,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
189
187
|
if rbln_compiled_models is None:
|
|
190
188
|
model_path_subfolder = cls._load_compiled_model_dir(
|
|
191
189
|
model_id=model_id,
|
|
192
|
-
|
|
190
|
+
token=token,
|
|
193
191
|
revision=revision,
|
|
194
192
|
force_download=force_download,
|
|
195
193
|
cache_dir=cache_dir,
|
|
@@ -232,7 +230,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
232
230
|
cache_dir=cache_dir,
|
|
233
231
|
force_download=force_download,
|
|
234
232
|
revision=revision,
|
|
235
|
-
token=
|
|
233
|
+
token=token,
|
|
236
234
|
trust_remote_code=trust_remote_code,
|
|
237
235
|
)
|
|
238
236
|
elif cls.hf_library_name == "diffusers":
|
|
@@ -250,7 +248,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
250
248
|
force_download=force_download,
|
|
251
249
|
local_files_only=local_files_only,
|
|
252
250
|
revision=revision,
|
|
253
|
-
token=
|
|
251
|
+
token=token,
|
|
254
252
|
subfolder=subfolder,
|
|
255
253
|
)
|
|
256
254
|
config = PretrainedConfig(**config)
|
|
@@ -316,7 +314,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
316
314
|
rbln_config,
|
|
317
315
|
model_save_dir=model_save_dir,
|
|
318
316
|
subfolder=subfolder,
|
|
319
|
-
rbln_compiled_models=
|
|
317
|
+
rbln_compiled_models=rbln_compiled_models,
|
|
320
318
|
rbln_submodules=rbln_submodules,
|
|
321
319
|
**kwargs,
|
|
322
320
|
)
|
|
@@ -344,32 +342,72 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
344
342
|
rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
|
|
345
343
|
return rbln_config, kwargs
|
|
346
344
|
|
|
345
|
+
@classmethod
|
|
346
|
+
def _is_compiled(
|
|
347
|
+
cls,
|
|
348
|
+
model_id: Union[str, Path],
|
|
349
|
+
token: Optional[Union[bool, str]] = None,
|
|
350
|
+
revision: Optional[str] = None,
|
|
351
|
+
force_download: bool = False,
|
|
352
|
+
cache_dir: Optional[str] = None,
|
|
353
|
+
subfolder: str = "",
|
|
354
|
+
local_files_only: bool = False,
|
|
355
|
+
) -> bool:
|
|
356
|
+
# Check if the model is already compiled.
|
|
357
|
+
try:
|
|
358
|
+
cls._load_compiled_model_dir(
|
|
359
|
+
model_id=model_id,
|
|
360
|
+
token=token,
|
|
361
|
+
revision=revision,
|
|
362
|
+
force_download=force_download,
|
|
363
|
+
cache_dir=cache_dir,
|
|
364
|
+
subfolder=subfolder,
|
|
365
|
+
local_files_only=local_files_only,
|
|
366
|
+
)
|
|
367
|
+
return True
|
|
368
|
+
except (FileNotFoundError, KeyError):
|
|
369
|
+
return False
|
|
370
|
+
|
|
347
371
|
@classmethod
|
|
348
372
|
def from_pretrained(
|
|
349
373
|
cls: Type["RBLNBaseModel"],
|
|
350
374
|
model_id: Union[str, Path],
|
|
351
|
-
export: bool =
|
|
375
|
+
export: Optional[bool] = None,
|
|
352
376
|
rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
|
|
353
|
-
**kwargs:
|
|
377
|
+
**kwargs: Any,
|
|
354
378
|
) -> "RBLNBaseModel":
|
|
355
379
|
"""
|
|
356
380
|
The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
|
|
357
381
|
User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
|
|
358
382
|
|
|
359
383
|
Args:
|
|
360
|
-
model_id: The model id of the pre-trained model to be loaded.
|
|
361
|
-
|
|
362
|
-
|
|
384
|
+
model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
|
|
385
|
+
It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
|
|
386
|
+
export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
|
|
387
|
+
If None, it will be determined based on the existence of the compiled model files in the model_id.
|
|
388
|
+
rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
|
|
389
|
+
This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
|
|
363
390
|
For detailed configuration options, see the specific model's configuration class documentation.
|
|
364
|
-
|
|
365
|
-
kwargs: Additional keyword arguments. Arguments with the prefix 'rbln_' are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
|
|
391
|
+
kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
|
|
366
392
|
|
|
367
393
|
Returns:
|
|
368
|
-
A RBLN model instance ready for inference on RBLN NPU devices.
|
|
394
|
+
(RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
|
|
369
395
|
"""
|
|
370
396
|
|
|
371
397
|
if isinstance(model_id, Path):
|
|
372
398
|
model_id = model_id.as_posix()
|
|
399
|
+
|
|
400
|
+
if export is None:
|
|
401
|
+
export = not cls._is_compiled(
|
|
402
|
+
model_id=model_id,
|
|
403
|
+
token=kwargs.get("token"),
|
|
404
|
+
revision=kwargs.get("revision"),
|
|
405
|
+
force_download=kwargs.get("force_download", False),
|
|
406
|
+
cache_dir=kwargs.get("cache_dir"),
|
|
407
|
+
subfolder=kwargs.get("subfolder", ""),
|
|
408
|
+
local_files_only=kwargs.get("local_files_only", False),
|
|
409
|
+
)
|
|
410
|
+
|
|
373
411
|
from_pretrained_method = cls._export if export else cls._from_pretrained
|
|
374
412
|
return from_pretrained_method(model_id=model_id, **kwargs, rbln_config=rbln_config)
|
|
375
413
|
|
|
@@ -394,7 +432,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
394
432
|
compiled_model = rebel.compile_from_torch(
|
|
395
433
|
model,
|
|
396
434
|
input_info=rbln_compile_config.input_info,
|
|
397
|
-
fusion=rbln_compile_config.fusion,
|
|
398
435
|
npu=rbln_compile_config.npu,
|
|
399
436
|
tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
|
|
400
437
|
**kwargs,
|
|
@@ -402,8 +439,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
402
439
|
return compiled_model
|
|
403
440
|
|
|
404
441
|
@classmethod
|
|
405
|
-
def update_rbln_config(
|
|
406
|
-
|
|
442
|
+
def update_rbln_config(
|
|
443
|
+
cls,
|
|
444
|
+
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
|
445
|
+
model: "PreTrainedModel",
|
|
446
|
+
model_config: "PretrainedConfig",
|
|
447
|
+
rbln_config: RBLNModelConfig,
|
|
448
|
+
) -> RBLNModelConfig:
|
|
449
|
+
rbln_config.torch_dtype = model.dtype
|
|
450
|
+
if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
|
|
451
|
+
raise NotImplementedError(
|
|
452
|
+
f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
|
|
453
|
+
)
|
|
454
|
+
rbln_config = cls._update_rbln_config(
|
|
455
|
+
preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
|
|
456
|
+
)
|
|
407
457
|
rbln_config.freeze()
|
|
408
458
|
if rbln_config.rbln_model_cls_name != cls.__name__:
|
|
409
459
|
raise NameError(
|
|
@@ -421,7 +471,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
421
471
|
|
|
422
472
|
# Returns:
|
|
423
473
|
# type: The original HuggingFace model class
|
|
424
|
-
if cls._hf_class is None:
|
|
474
|
+
if "_hf_class" not in cls.__dict__ or cls._hf_class is None:
|
|
425
475
|
hf_cls_name = cls.__name__[4:]
|
|
426
476
|
library = importlib.import_module(cls.hf_library_name)
|
|
427
477
|
cls._hf_class = getattr(library, hf_cls_name, None)
|
|
@@ -430,7 +480,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
430
480
|
@classmethod
|
|
431
481
|
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
|
|
432
482
|
# Lazily loads and caches the corresponding RBLN model config class.
|
|
433
|
-
if cls._rbln_config_class is None:
|
|
483
|
+
if "_rbln_config_class" not in cls.__dict__ or cls._rbln_config_class is None:
|
|
434
484
|
rbln_config_class_name = cls.__name__ + "Config"
|
|
435
485
|
cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
|
|
436
486
|
return cls._rbln_config_class
|
|
@@ -446,12 +496,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
446
496
|
|
|
447
497
|
# This method mimics the interface of torch.nn.Module.parameters()
|
|
448
498
|
# specifically for code that uses `next(model.parameters())` to infer
|
|
449
|
-
# the device or dtype. It yields a single dummy tensor on CPU with
|
|
499
|
+
# the device or dtype. It yields a single dummy tensor on CPU with model dtype.
|
|
450
500
|
|
|
451
501
|
# Warning:
|
|
452
502
|
# This does NOT yield the actual model parameters used by the RBLN runtime.
|
|
453
503
|
# Code relying on iterating through all model parameters will not work as expected.
|
|
454
|
-
yield torch.tensor([1.0], dtype=
|
|
504
|
+
yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))
|
|
455
505
|
|
|
456
506
|
def __call__(self, *args, **kwargs):
|
|
457
507
|
return self.forward(*args, **kwargs)
|
|
@@ -486,9 +536,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
486
536
|
[`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.
|
|
487
537
|
|
|
488
538
|
Args:
|
|
489
|
-
save_directory (
|
|
539
|
+
save_directory (Union[str, Path]):
|
|
490
540
|
Directory where to save the model file.
|
|
491
|
-
push_to_hub (
|
|
541
|
+
push_to_hub (bool):
|
|
492
542
|
Whether or not to push your model to the HuggingFace model hub after saving it.
|
|
493
543
|
|
|
494
544
|
"""
|
|
@@ -507,6 +557,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
507
557
|
f"Please ensure the model directory exists and you have the necessary permissions to access it."
|
|
508
558
|
)
|
|
509
559
|
|
|
560
|
+
if isinstance(self.config, PretrainedConfig):
|
|
561
|
+
self.config.save_pretrained(real_save_dir)
|
|
562
|
+
|
|
510
563
|
if save_directory_path == real_save_dir:
|
|
511
564
|
raise FileExistsError(
|
|
512
565
|
f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
|
|
@@ -522,10 +575,35 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
522
575
|
# First copy everything to a temporary directory
|
|
523
576
|
shutil.copytree(real_save_dir, tmp_dir)
|
|
524
577
|
|
|
525
|
-
# If everything succeeded,
|
|
578
|
+
# If everything succeeded, move files to target directory
|
|
526
579
|
if os.path.exists(save_directory_path):
|
|
527
|
-
|
|
528
|
-
|
|
580
|
+
# Merge files from tmp_dir into existing directory
|
|
581
|
+
def _merge_dir(src_root: str, dst_root: str):
|
|
582
|
+
for name in os.listdir(src_root):
|
|
583
|
+
src_item = os.path.join(src_root, name)
|
|
584
|
+
dst_item = os.path.join(dst_root, name)
|
|
585
|
+
|
|
586
|
+
if os.path.islink(src_item) or os.path.isfile(src_item):
|
|
587
|
+
os.makedirs(os.path.dirname(dst_item), exist_ok=True)
|
|
588
|
+
if os.path.isdir(dst_item) and not os.path.islink(dst_item):
|
|
589
|
+
shutil.rmtree(dst_item)
|
|
590
|
+
os.replace(src_item, dst_item)
|
|
591
|
+
elif os.path.isdir(src_item):
|
|
592
|
+
if os.path.islink(dst_item) or os.path.isfile(dst_item):
|
|
593
|
+
os.remove(dst_item)
|
|
594
|
+
os.makedirs(dst_item, exist_ok=True)
|
|
595
|
+
_merge_dir(src_item, dst_item)
|
|
596
|
+
else:
|
|
597
|
+
# Fallback for special file types
|
|
598
|
+
os.replace(src_item, dst_item)
|
|
599
|
+
|
|
600
|
+
_merge_dir(tmp_dir, str(save_directory_path))
|
|
601
|
+
|
|
602
|
+
# Remove the temporary directory tree after merge
|
|
603
|
+
shutil.rmtree(tmp_dir)
|
|
604
|
+
else:
|
|
605
|
+
# If target doesn't exist, just rename tmp_dir to target
|
|
606
|
+
os.rename(tmp_dir, save_directory_path)
|
|
529
607
|
|
|
530
608
|
except Exception as e:
|
|
531
609
|
# Clean up the temporary directory if anything fails
|
|
@@ -534,7 +612,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
534
612
|
raise e # Re-raise the exception after cleanup
|
|
535
613
|
|
|
536
614
|
if push_to_hub:
|
|
537
|
-
|
|
615
|
+
repo_id = kwargs.pop("repo_id", None)
|
|
616
|
+
if repo_id is None:
|
|
617
|
+
raise ValueError("`repo_id` must be provided to push the model to the HuggingFace model hub.")
|
|
618
|
+
return super().push_to_hub(repo_id=repo_id, **kwargs)
|
|
538
619
|
|
|
539
620
|
@staticmethod
|
|
540
621
|
def _raise_missing_compiled_file_error(missing_files: List[str]):
|
optimum/rbln/ops/attn.py
CHANGED
|
@@ -53,6 +53,45 @@ def paged_attn_decode_fake(
|
|
|
53
53
|
return torch.empty_like(q)
|
|
54
54
|
|
|
55
55
|
|
|
56
|
+
@torch.library.custom_op(
|
|
57
|
+
"rbln_custom_ops::paged_attn_decode_kv_fp8",
|
|
58
|
+
mutates_args=(["kcache", "vcache"]),
|
|
59
|
+
)
|
|
60
|
+
def paged_attn_decode_kv_fp8(
|
|
61
|
+
q: Tensor,
|
|
62
|
+
k: Tensor,
|
|
63
|
+
v: Tensor,
|
|
64
|
+
mask: Tensor,
|
|
65
|
+
kcache: Tensor,
|
|
66
|
+
vcache: Tensor,
|
|
67
|
+
seq: Tensor,
|
|
68
|
+
scale: Tensor,
|
|
69
|
+
block_table: Tensor,
|
|
70
|
+
block_size: int,
|
|
71
|
+
k_scale: Tensor,
|
|
72
|
+
v_scale: Tensor,
|
|
73
|
+
) -> Tensor:
|
|
74
|
+
return torch.empty_like(q)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@paged_attn_decode_kv_fp8.register_fake
|
|
78
|
+
def paged_attn_decode_kv_fp8_fake(
|
|
79
|
+
q: Tensor,
|
|
80
|
+
k: Tensor,
|
|
81
|
+
v: Tensor,
|
|
82
|
+
mask: Tensor,
|
|
83
|
+
kcache: Tensor,
|
|
84
|
+
vcache: Tensor,
|
|
85
|
+
seq: Tensor,
|
|
86
|
+
scale: Tensor,
|
|
87
|
+
block_table: Tensor,
|
|
88
|
+
block_size: int,
|
|
89
|
+
k_scale: Tensor,
|
|
90
|
+
v_scale: Tensor,
|
|
91
|
+
) -> Tensor:
|
|
92
|
+
return torch.empty_like(q)
|
|
93
|
+
|
|
94
|
+
|
|
56
95
|
@torch.library.custom_op(
|
|
57
96
|
"rbln_custom_ops::paged_attn_prefill",
|
|
58
97
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -112,6 +151,45 @@ def paged_attn_prefill_fake(
|
|
|
112
151
|
return torch.empty_like(q)
|
|
113
152
|
|
|
114
153
|
|
|
154
|
+
@torch.library.custom_op(
|
|
155
|
+
"rbln_custom_ops::paged_attn_prefill_kv_fp8",
|
|
156
|
+
mutates_args=(["kcache", "vcache"]),
|
|
157
|
+
)
|
|
158
|
+
def paged_attn_prefill_kv_fp8(
|
|
159
|
+
q: Tensor,
|
|
160
|
+
k: Tensor,
|
|
161
|
+
v: Tensor,
|
|
162
|
+
mask: Tensor,
|
|
163
|
+
kcache: Tensor,
|
|
164
|
+
vcache: Tensor,
|
|
165
|
+
seq: Tensor,
|
|
166
|
+
scale: Tensor,
|
|
167
|
+
block_table: Tensor,
|
|
168
|
+
block_size: int,
|
|
169
|
+
k_scale: Tensor,
|
|
170
|
+
v_scale: Tensor,
|
|
171
|
+
) -> Tensor:
|
|
172
|
+
return torch.empty_like(q)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@paged_attn_prefill_kv_fp8.register_fake
|
|
176
|
+
def paged_attn_prefill_kv_fp8_fake(
|
|
177
|
+
q: Tensor,
|
|
178
|
+
k: Tensor,
|
|
179
|
+
v: Tensor,
|
|
180
|
+
mask: Tensor,
|
|
181
|
+
kcache: Tensor,
|
|
182
|
+
vcache: Tensor,
|
|
183
|
+
seq: Tensor,
|
|
184
|
+
scale: Tensor,
|
|
185
|
+
block_table: Tensor,
|
|
186
|
+
block_size: int,
|
|
187
|
+
k_scale: Tensor,
|
|
188
|
+
v_scale: Tensor,
|
|
189
|
+
) -> Tensor:
|
|
190
|
+
return torch.empty_like(q)
|
|
191
|
+
|
|
192
|
+
|
|
115
193
|
@torch.library.custom_op(
|
|
116
194
|
"rbln_custom_ops::paged_causal_attn_decode",
|
|
117
195
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -236,6 +314,86 @@ def paged_causal_attn_prefill_fake(
|
|
|
236
314
|
return torch.empty_like(q)
|
|
237
315
|
|
|
238
316
|
|
|
317
|
+
@torch.library.custom_op(
|
|
318
|
+
"rbln_custom_ops::paged_causal_attn_decode_kv_fp8",
|
|
319
|
+
mutates_args=(["kcache", "vcache"]),
|
|
320
|
+
)
|
|
321
|
+
def paged_causal_attn_decode_kv_fp8(
|
|
322
|
+
q: Tensor,
|
|
323
|
+
k: Tensor,
|
|
324
|
+
v: Tensor,
|
|
325
|
+
kcache: Tensor,
|
|
326
|
+
vcache: Tensor,
|
|
327
|
+
seq: Tensor,
|
|
328
|
+
scale: Tensor,
|
|
329
|
+
block_table: Tensor,
|
|
330
|
+
block_size: int,
|
|
331
|
+
k_scale: Tensor,
|
|
332
|
+
v_scale: Tensor,
|
|
333
|
+
mask: Optional[Tensor] = None,
|
|
334
|
+
) -> Tensor:
|
|
335
|
+
return torch.empty_like(q)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@paged_causal_attn_decode_kv_fp8.register_fake
|
|
339
|
+
def paged_causal_attn_decode_kv_fp8_fake(
|
|
340
|
+
q: Tensor,
|
|
341
|
+
k: Tensor,
|
|
342
|
+
v: Tensor,
|
|
343
|
+
kcache: Tensor,
|
|
344
|
+
vcache: Tensor,
|
|
345
|
+
seq: Tensor,
|
|
346
|
+
scale: Tensor,
|
|
347
|
+
block_table: Tensor,
|
|
348
|
+
block_size: int,
|
|
349
|
+
k_scale: Tensor,
|
|
350
|
+
v_scale: Tensor,
|
|
351
|
+
mask: Optional[Tensor] = None,
|
|
352
|
+
) -> Tensor:
|
|
353
|
+
return torch.empty_like(q)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
@torch.library.custom_op(
|
|
357
|
+
"rbln_custom_ops::paged_causal_attn_prefill_kv_fp8",
|
|
358
|
+
mutates_args=(["kcache", "vcache"]),
|
|
359
|
+
)
|
|
360
|
+
def paged_causal_attn_prefill_kv_fp8(
|
|
361
|
+
q: Tensor,
|
|
362
|
+
k: Tensor,
|
|
363
|
+
v: Tensor,
|
|
364
|
+
kcache: Tensor,
|
|
365
|
+
vcache: Tensor,
|
|
366
|
+
seq: Tensor,
|
|
367
|
+
scale: Tensor,
|
|
368
|
+
block_table: Tensor,
|
|
369
|
+
block_size: int,
|
|
370
|
+
is_bidirectional: bool,
|
|
371
|
+
k_scale: Tensor,
|
|
372
|
+
v_scale: Tensor,
|
|
373
|
+
mask: Optional[Tensor] = None,
|
|
374
|
+
) -> Tensor:
|
|
375
|
+
return torch.empty_like(q)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
@paged_causal_attn_prefill_kv_fp8.register_fake
|
|
379
|
+
def paged_causal_attn_prefill_kv_fp8_fake(
|
|
380
|
+
q: Tensor,
|
|
381
|
+
k: Tensor,
|
|
382
|
+
v: Tensor,
|
|
383
|
+
kcache: Tensor,
|
|
384
|
+
vcache: Tensor,
|
|
385
|
+
seq: Tensor,
|
|
386
|
+
scale: Tensor,
|
|
387
|
+
block_table: Tensor,
|
|
388
|
+
block_size: int,
|
|
389
|
+
is_bidirectional: bool,
|
|
390
|
+
k_scale: Tensor,
|
|
391
|
+
v_scale: Tensor,
|
|
392
|
+
mask: Optional[Tensor] = None,
|
|
393
|
+
) -> Tensor:
|
|
394
|
+
return torch.empty_like(q)
|
|
395
|
+
|
|
396
|
+
|
|
239
397
|
@torch.library.custom_op(
|
|
240
398
|
"rbln_custom_ops::paged_add_softmax_attn_decode",
|
|
241
399
|
mutates_args=(["kcache", "vcache"]),
|
optimum/rbln/ops/flash_attn.py
CHANGED
|
@@ -59,6 +59,47 @@ def paged_flash_attn_decode_fake(
|
|
|
59
59
|
return torch.empty_like(q)
|
|
60
60
|
|
|
61
61
|
|
|
62
|
+
@torch.library.custom_op(
|
|
63
|
+
"rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
|
|
64
|
+
mutates_args=(["kcache", "vcache"]),
|
|
65
|
+
)
|
|
66
|
+
def paged_flash_attn_decode_kv_fp8(
|
|
67
|
+
q: Tensor,
|
|
68
|
+
k: Tensor,
|
|
69
|
+
v: Tensor,
|
|
70
|
+
mask: Tensor,
|
|
71
|
+
kcache: Tensor,
|
|
72
|
+
vcache: Tensor,
|
|
73
|
+
seq: Tensor,
|
|
74
|
+
scale: Tensor,
|
|
75
|
+
block_table: Tensor,
|
|
76
|
+
block_size: int,
|
|
77
|
+
partition: int,
|
|
78
|
+
k_scale: Tensor,
|
|
79
|
+
v_scale: Tensor,
|
|
80
|
+
) -> Tensor:
|
|
81
|
+
return torch.empty_like(q)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@paged_flash_attn_decode_kv_fp8.register_fake
|
|
85
|
+
def paged_flash_attn_decode_kv_fp8_fake(
|
|
86
|
+
q: Tensor,
|
|
87
|
+
k: Tensor,
|
|
88
|
+
v: Tensor,
|
|
89
|
+
mask: Tensor,
|
|
90
|
+
kcache: Tensor,
|
|
91
|
+
vcache: Tensor,
|
|
92
|
+
seq: Tensor,
|
|
93
|
+
scale: Tensor,
|
|
94
|
+
block_table: Tensor,
|
|
95
|
+
block_size: int,
|
|
96
|
+
partition: int,
|
|
97
|
+
k_scale: Tensor,
|
|
98
|
+
v_scale: Tensor,
|
|
99
|
+
) -> Tensor:
|
|
100
|
+
return torch.empty_like(q)
|
|
101
|
+
|
|
102
|
+
|
|
62
103
|
@torch.library.custom_op(
|
|
63
104
|
"rbln_custom_ops::paged_flash_attn_prefill",
|
|
64
105
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -100,6 +141,47 @@ def paged_flash_attn_prefill_fake(
|
|
|
100
141
|
return torch.empty_like(q)
|
|
101
142
|
|
|
102
143
|
|
|
144
|
+
@torch.library.custom_op(
|
|
145
|
+
"rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
|
|
146
|
+
mutates_args=(["kcache", "vcache"]),
|
|
147
|
+
)
|
|
148
|
+
def paged_flash_attn_prefill_kv_fp8(
|
|
149
|
+
q: Tensor,
|
|
150
|
+
k: Tensor,
|
|
151
|
+
v: Tensor,
|
|
152
|
+
mask: Tensor,
|
|
153
|
+
kcache: Tensor,
|
|
154
|
+
vcache: Tensor,
|
|
155
|
+
seq: Tensor,
|
|
156
|
+
scale: Tensor,
|
|
157
|
+
block_table: Tensor,
|
|
158
|
+
block_size: int,
|
|
159
|
+
partition: int,
|
|
160
|
+
k_scale: Tensor,
|
|
161
|
+
v_scale: Tensor,
|
|
162
|
+
) -> Tensor:
|
|
163
|
+
return torch.empty_like(q)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@paged_flash_attn_prefill_kv_fp8.register_fake
|
|
167
|
+
def paged_flash_attn_prefill_kv_fp8_fake(
|
|
168
|
+
q: Tensor,
|
|
169
|
+
k: Tensor,
|
|
170
|
+
v: Tensor,
|
|
171
|
+
mask: Tensor,
|
|
172
|
+
kcache: Tensor,
|
|
173
|
+
vcache: Tensor,
|
|
174
|
+
seq: Tensor,
|
|
175
|
+
scale: Tensor,
|
|
176
|
+
block_table: Tensor,
|
|
177
|
+
block_size: int,
|
|
178
|
+
partition: int,
|
|
179
|
+
k_scale: Tensor,
|
|
180
|
+
v_scale: Tensor,
|
|
181
|
+
) -> Tensor:
|
|
182
|
+
return torch.empty_like(q)
|
|
183
|
+
|
|
184
|
+
|
|
103
185
|
@torch.library.custom_op(
|
|
104
186
|
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
|
105
187
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -141,6 +223,47 @@ def paged_flash_causal_attn_decode_fake(
|
|
|
141
223
|
return torch.empty_like(q)
|
|
142
224
|
|
|
143
225
|
|
|
226
|
+
@torch.library.custom_op(
|
|
227
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
|
|
228
|
+
mutates_args=(["kcache", "vcache"]),
|
|
229
|
+
)
|
|
230
|
+
def paged_flash_causal_attn_decode_kv_fp8(
|
|
231
|
+
q: Tensor,
|
|
232
|
+
k: Tensor,
|
|
233
|
+
v: Tensor,
|
|
234
|
+
kcache: Tensor,
|
|
235
|
+
vcache: Tensor,
|
|
236
|
+
seq: Tensor,
|
|
237
|
+
scale: Tensor,
|
|
238
|
+
block_table: Tensor,
|
|
239
|
+
block_size: int,
|
|
240
|
+
partition: int,
|
|
241
|
+
k_scale: Tensor,
|
|
242
|
+
v_scale: Tensor,
|
|
243
|
+
mask: Optional[Tensor] = None,
|
|
244
|
+
) -> Tensor:
|
|
245
|
+
return torch.empty_like(q)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@paged_flash_causal_attn_decode_kv_fp8.register_fake
|
|
249
|
+
def paged_flash_causal_attn_decode_kv_fp8_fake(
|
|
250
|
+
q: Tensor,
|
|
251
|
+
k: Tensor,
|
|
252
|
+
v: Tensor,
|
|
253
|
+
kcache: Tensor,
|
|
254
|
+
vcache: Tensor,
|
|
255
|
+
seq: Tensor,
|
|
256
|
+
scale: Tensor,
|
|
257
|
+
block_table: Tensor,
|
|
258
|
+
block_size: int,
|
|
259
|
+
partition: int,
|
|
260
|
+
k_scale: Tensor,
|
|
261
|
+
v_scale: Tensor,
|
|
262
|
+
mask: Optional[Tensor] = None,
|
|
263
|
+
) -> Tensor:
|
|
264
|
+
return torch.empty_like(q)
|
|
265
|
+
|
|
266
|
+
|
|
144
267
|
@torch.library.custom_op(
|
|
145
268
|
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
|
146
269
|
mutates_args=(["kcache", "vcache"]),
|
|
@@ -182,3 +305,46 @@ def paged_flash_causal_attn_prefill_fake(
|
|
|
182
305
|
mask: Optional[Tensor] = None,
|
|
183
306
|
) -> Tensor:
|
|
184
307
|
return torch.empty_like(q)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@torch.library.custom_op(
|
|
311
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
|
|
312
|
+
mutates_args=(["kcache", "vcache"]),
|
|
313
|
+
)
|
|
314
|
+
def paged_flash_causal_attn_prefill_kv_fp8(
|
|
315
|
+
q: Tensor,
|
|
316
|
+
k: Tensor,
|
|
317
|
+
v: Tensor,
|
|
318
|
+
kcache: Tensor,
|
|
319
|
+
vcache: Tensor,
|
|
320
|
+
seq: Tensor,
|
|
321
|
+
scale: Tensor,
|
|
322
|
+
block_table: Tensor,
|
|
323
|
+
block_size: int,
|
|
324
|
+
partition: int,
|
|
325
|
+
is_bidirectional: bool,
|
|
326
|
+
k_scale: Tensor,
|
|
327
|
+
v_scale: Tensor,
|
|
328
|
+
mask: Optional[Tensor] = None,
|
|
329
|
+
) -> Tensor:
|
|
330
|
+
return torch.empty_like(q)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@paged_flash_causal_attn_prefill_kv_fp8.register_fake
|
|
334
|
+
def paged_flash_causal_attn_prefill_kv_fp8_fake(
|
|
335
|
+
q: Tensor,
|
|
336
|
+
k: Tensor,
|
|
337
|
+
v: Tensor,
|
|
338
|
+
kcache: Tensor,
|
|
339
|
+
vcache: Tensor,
|
|
340
|
+
seq: Tensor,
|
|
341
|
+
scale: Tensor,
|
|
342
|
+
block_table: Tensor,
|
|
343
|
+
block_size: int,
|
|
344
|
+
partition: int,
|
|
345
|
+
is_bidirectional: bool,
|
|
346
|
+
k_scale: Tensor,
|
|
347
|
+
v_scale: Tensor,
|
|
348
|
+
mask: Optional[Tensor] = None,
|
|
349
|
+
) -> Tensor:
|
|
350
|
+
return torch.empty_like(q)
|