optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import Tensor
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@torch.library.custom_op(
|
|
22
|
+
"rbln_custom_ops::paged_flash_attn_decode",
|
|
23
|
+
mutates_args=(["kcache", "vcache"]),
|
|
24
|
+
)
|
|
25
|
+
def paged_flash_attn_decode(
|
|
26
|
+
q: Tensor,
|
|
27
|
+
k: Tensor,
|
|
28
|
+
v: Tensor,
|
|
29
|
+
mask: Tensor,
|
|
30
|
+
kcache: Tensor,
|
|
31
|
+
vcache: Tensor,
|
|
32
|
+
seq: Tensor,
|
|
33
|
+
scale: Tensor,
|
|
34
|
+
block_table: Tensor,
|
|
35
|
+
block_size: int,
|
|
36
|
+
partition: int,
|
|
37
|
+
) -> Tensor:
|
|
38
|
+
"""Defines the computation pattern for fused flash attention with KV cache for decoding.
|
|
39
|
+
|
|
40
|
+
Returns a tensor with the same shape as q.
|
|
41
|
+
"""
|
|
42
|
+
return torch.empty_like(q)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@paged_flash_attn_decode.register_fake
|
|
46
|
+
def paged_flash_attn_decode_fake(
|
|
47
|
+
q: Tensor,
|
|
48
|
+
k: Tensor,
|
|
49
|
+
v: Tensor,
|
|
50
|
+
mask: Tensor,
|
|
51
|
+
kcache: Tensor,
|
|
52
|
+
vcache: Tensor,
|
|
53
|
+
seq: Tensor,
|
|
54
|
+
scale: Tensor,
|
|
55
|
+
block_table: Tensor,
|
|
56
|
+
block_size: int,
|
|
57
|
+
partition: int,
|
|
58
|
+
) -> Tensor:
|
|
59
|
+
return torch.empty_like(q)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@torch.library.custom_op(
|
|
63
|
+
"rbln_custom_ops::paged_flash_attn_decode_kv_fp8",
|
|
64
|
+
mutates_args=(["kcache", "vcache"]),
|
|
65
|
+
)
|
|
66
|
+
def paged_flash_attn_decode_kv_fp8(
|
|
67
|
+
q: Tensor,
|
|
68
|
+
k: Tensor,
|
|
69
|
+
v: Tensor,
|
|
70
|
+
mask: Tensor,
|
|
71
|
+
kcache: Tensor,
|
|
72
|
+
vcache: Tensor,
|
|
73
|
+
seq: Tensor,
|
|
74
|
+
scale: Tensor,
|
|
75
|
+
block_table: Tensor,
|
|
76
|
+
block_size: int,
|
|
77
|
+
partition: int,
|
|
78
|
+
k_scale: Tensor,
|
|
79
|
+
v_scale: Tensor,
|
|
80
|
+
) -> Tensor:
|
|
81
|
+
return torch.empty_like(q)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@paged_flash_attn_decode_kv_fp8.register_fake
|
|
85
|
+
def paged_flash_attn_decode_kv_fp8_fake(
|
|
86
|
+
q: Tensor,
|
|
87
|
+
k: Tensor,
|
|
88
|
+
v: Tensor,
|
|
89
|
+
mask: Tensor,
|
|
90
|
+
kcache: Tensor,
|
|
91
|
+
vcache: Tensor,
|
|
92
|
+
seq: Tensor,
|
|
93
|
+
scale: Tensor,
|
|
94
|
+
block_table: Tensor,
|
|
95
|
+
block_size: int,
|
|
96
|
+
partition: int,
|
|
97
|
+
k_scale: Tensor,
|
|
98
|
+
v_scale: Tensor,
|
|
99
|
+
) -> Tensor:
|
|
100
|
+
return torch.empty_like(q)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@torch.library.custom_op(
|
|
104
|
+
"rbln_custom_ops::paged_flash_attn_prefill",
|
|
105
|
+
mutates_args=(["kcache", "vcache"]),
|
|
106
|
+
)
|
|
107
|
+
def paged_flash_attn_prefill(
|
|
108
|
+
q: Tensor,
|
|
109
|
+
k: Tensor,
|
|
110
|
+
v: Tensor,
|
|
111
|
+
mask: Tensor,
|
|
112
|
+
kcache: Tensor,
|
|
113
|
+
vcache: Tensor,
|
|
114
|
+
seq: Tensor,
|
|
115
|
+
scale: Tensor,
|
|
116
|
+
block_table: Tensor,
|
|
117
|
+
block_size: int,
|
|
118
|
+
partition: int,
|
|
119
|
+
) -> Tensor:
|
|
120
|
+
"""Defines the computation pattern for fused flash attention with KV cache for prefill.
|
|
121
|
+
|
|
122
|
+
Returns a tensor with the same shape as q.
|
|
123
|
+
"""
|
|
124
|
+
return torch.empty_like(q)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@paged_flash_attn_prefill.register_fake
|
|
128
|
+
def paged_flash_attn_prefill_fake(
|
|
129
|
+
q: Tensor,
|
|
130
|
+
k: Tensor,
|
|
131
|
+
v: Tensor,
|
|
132
|
+
mask: Tensor,
|
|
133
|
+
kcache: Tensor,
|
|
134
|
+
vcache: Tensor,
|
|
135
|
+
seq: Tensor,
|
|
136
|
+
scale: Tensor,
|
|
137
|
+
block_table: Tensor,
|
|
138
|
+
block_size: int,
|
|
139
|
+
partition: int,
|
|
140
|
+
) -> Tensor:
|
|
141
|
+
return torch.empty_like(q)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@torch.library.custom_op(
|
|
145
|
+
"rbln_custom_ops::paged_flash_attn_prefill_kv_fp8",
|
|
146
|
+
mutates_args=(["kcache", "vcache"]),
|
|
147
|
+
)
|
|
148
|
+
def paged_flash_attn_prefill_kv_fp8(
|
|
149
|
+
q: Tensor,
|
|
150
|
+
k: Tensor,
|
|
151
|
+
v: Tensor,
|
|
152
|
+
mask: Tensor,
|
|
153
|
+
kcache: Tensor,
|
|
154
|
+
vcache: Tensor,
|
|
155
|
+
seq: Tensor,
|
|
156
|
+
scale: Tensor,
|
|
157
|
+
block_table: Tensor,
|
|
158
|
+
block_size: int,
|
|
159
|
+
partition: int,
|
|
160
|
+
k_scale: Tensor,
|
|
161
|
+
v_scale: Tensor,
|
|
162
|
+
) -> Tensor:
|
|
163
|
+
return torch.empty_like(q)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@paged_flash_attn_prefill_kv_fp8.register_fake
|
|
167
|
+
def paged_flash_attn_prefill_kv_fp8_fake(
|
|
168
|
+
q: Tensor,
|
|
169
|
+
k: Tensor,
|
|
170
|
+
v: Tensor,
|
|
171
|
+
mask: Tensor,
|
|
172
|
+
kcache: Tensor,
|
|
173
|
+
vcache: Tensor,
|
|
174
|
+
seq: Tensor,
|
|
175
|
+
scale: Tensor,
|
|
176
|
+
block_table: Tensor,
|
|
177
|
+
block_size: int,
|
|
178
|
+
partition: int,
|
|
179
|
+
k_scale: Tensor,
|
|
180
|
+
v_scale: Tensor,
|
|
181
|
+
) -> Tensor:
|
|
182
|
+
return torch.empty_like(q)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@torch.library.custom_op(
|
|
186
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
|
187
|
+
mutates_args=(["kcache", "vcache"]),
|
|
188
|
+
)
|
|
189
|
+
def paged_flash_causal_attn_decode(
|
|
190
|
+
q: Tensor,
|
|
191
|
+
k: Tensor,
|
|
192
|
+
v: Tensor,
|
|
193
|
+
kcache: Tensor,
|
|
194
|
+
vcache: Tensor,
|
|
195
|
+
seq: Tensor,
|
|
196
|
+
scale: Tensor,
|
|
197
|
+
block_table: Tensor,
|
|
198
|
+
block_size: int,
|
|
199
|
+
partition: int,
|
|
200
|
+
mask: Optional[Tensor] = None,
|
|
201
|
+
) -> Tensor:
|
|
202
|
+
"""Defines the computation pattern for fused causal flash attention with KV cache for decoding.
|
|
203
|
+
|
|
204
|
+
Returns a tensor with the same shape as q.
|
|
205
|
+
"""
|
|
206
|
+
return torch.empty_like(q)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@paged_flash_causal_attn_decode.register_fake
|
|
210
|
+
def paged_flash_causal_attn_decode_fake(
|
|
211
|
+
q: Tensor,
|
|
212
|
+
k: Tensor,
|
|
213
|
+
v: Tensor,
|
|
214
|
+
kcache: Tensor,
|
|
215
|
+
vcache: Tensor,
|
|
216
|
+
seq: Tensor,
|
|
217
|
+
scale: Tensor,
|
|
218
|
+
block_table: Tensor,
|
|
219
|
+
block_size: int,
|
|
220
|
+
partition: int,
|
|
221
|
+
mask: Optional[Tensor] = None,
|
|
222
|
+
) -> Tensor:
|
|
223
|
+
return torch.empty_like(q)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@torch.library.custom_op(
|
|
227
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode_kv_fp8",
|
|
228
|
+
mutates_args=(["kcache", "vcache"]),
|
|
229
|
+
)
|
|
230
|
+
def paged_flash_causal_attn_decode_kv_fp8(
|
|
231
|
+
q: Tensor,
|
|
232
|
+
k: Tensor,
|
|
233
|
+
v: Tensor,
|
|
234
|
+
kcache: Tensor,
|
|
235
|
+
vcache: Tensor,
|
|
236
|
+
seq: Tensor,
|
|
237
|
+
scale: Tensor,
|
|
238
|
+
block_table: Tensor,
|
|
239
|
+
block_size: int,
|
|
240
|
+
partition: int,
|
|
241
|
+
k_scale: Tensor,
|
|
242
|
+
v_scale: Tensor,
|
|
243
|
+
mask: Optional[Tensor] = None,
|
|
244
|
+
) -> Tensor:
|
|
245
|
+
return torch.empty_like(q)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@paged_flash_causal_attn_decode_kv_fp8.register_fake
|
|
249
|
+
def paged_flash_causal_attn_decode_kv_fp8_fake(
|
|
250
|
+
q: Tensor,
|
|
251
|
+
k: Tensor,
|
|
252
|
+
v: Tensor,
|
|
253
|
+
kcache: Tensor,
|
|
254
|
+
vcache: Tensor,
|
|
255
|
+
seq: Tensor,
|
|
256
|
+
scale: Tensor,
|
|
257
|
+
block_table: Tensor,
|
|
258
|
+
block_size: int,
|
|
259
|
+
partition: int,
|
|
260
|
+
k_scale: Tensor,
|
|
261
|
+
v_scale: Tensor,
|
|
262
|
+
mask: Optional[Tensor] = None,
|
|
263
|
+
) -> Tensor:
|
|
264
|
+
return torch.empty_like(q)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
@torch.library.custom_op(
|
|
268
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
|
269
|
+
mutates_args=(["kcache", "vcache"]),
|
|
270
|
+
)
|
|
271
|
+
def paged_flash_causal_attn_prefill(
|
|
272
|
+
q: Tensor,
|
|
273
|
+
k: Tensor,
|
|
274
|
+
v: Tensor,
|
|
275
|
+
kcache: Tensor,
|
|
276
|
+
vcache: Tensor,
|
|
277
|
+
seq: Tensor,
|
|
278
|
+
scale: Tensor,
|
|
279
|
+
block_table: Tensor,
|
|
280
|
+
block_size: int,
|
|
281
|
+
partition: int,
|
|
282
|
+
is_bidirectional: bool,
|
|
283
|
+
mask: Optional[Tensor] = None,
|
|
284
|
+
) -> Tensor:
|
|
285
|
+
"""Defines the computation pattern for fused causal flash attention with KV cache for prefill.
|
|
286
|
+
|
|
287
|
+
Returns a tensor with the same shape as q.
|
|
288
|
+
"""
|
|
289
|
+
return torch.empty_like(q)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
@paged_flash_causal_attn_prefill.register_fake
|
|
293
|
+
def paged_flash_causal_attn_prefill_fake(
|
|
294
|
+
q: Tensor,
|
|
295
|
+
k: Tensor,
|
|
296
|
+
v: Tensor,
|
|
297
|
+
kcache: Tensor,
|
|
298
|
+
vcache: Tensor,
|
|
299
|
+
seq: Tensor,
|
|
300
|
+
scale: Tensor,
|
|
301
|
+
block_table: Tensor,
|
|
302
|
+
block_size: int,
|
|
303
|
+
partition: int,
|
|
304
|
+
is_bidirectional: bool,
|
|
305
|
+
mask: Optional[Tensor] = None,
|
|
306
|
+
) -> Tensor:
|
|
307
|
+
return torch.empty_like(q)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@torch.library.custom_op(
|
|
311
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill_kv_fp8",
|
|
312
|
+
mutates_args=(["kcache", "vcache"]),
|
|
313
|
+
)
|
|
314
|
+
def paged_flash_causal_attn_prefill_kv_fp8(
|
|
315
|
+
q: Tensor,
|
|
316
|
+
k: Tensor,
|
|
317
|
+
v: Tensor,
|
|
318
|
+
kcache: Tensor,
|
|
319
|
+
vcache: Tensor,
|
|
320
|
+
seq: Tensor,
|
|
321
|
+
scale: Tensor,
|
|
322
|
+
block_table: Tensor,
|
|
323
|
+
block_size: int,
|
|
324
|
+
partition: int,
|
|
325
|
+
is_bidirectional: bool,
|
|
326
|
+
k_scale: Tensor,
|
|
327
|
+
v_scale: Tensor,
|
|
328
|
+
mask: Optional[Tensor] = None,
|
|
329
|
+
) -> Tensor:
|
|
330
|
+
return torch.empty_like(q)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
@paged_flash_causal_attn_prefill_kv_fp8.register_fake
|
|
334
|
+
def paged_flash_causal_attn_prefill_kv_fp8_fake(
|
|
335
|
+
q: Tensor,
|
|
336
|
+
k: Tensor,
|
|
337
|
+
v: Tensor,
|
|
338
|
+
kcache: Tensor,
|
|
339
|
+
vcache: Tensor,
|
|
340
|
+
seq: Tensor,
|
|
341
|
+
scale: Tensor,
|
|
342
|
+
block_table: Tensor,
|
|
343
|
+
block_size: int,
|
|
344
|
+
partition: int,
|
|
345
|
+
is_bidirectional: bool,
|
|
346
|
+
k_scale: Tensor,
|
|
347
|
+
v_scale: Tensor,
|
|
348
|
+
mask: Optional[Tensor] = None,
|
|
349
|
+
) -> Tensor:
|
|
350
|
+
return torch.empty_like(q)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch import Tensor
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
|
|
20
|
+
def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
|
|
21
|
+
# Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
|
|
22
|
+
# This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
|
|
23
|
+
# The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
|
|
24
|
+
return torch.empty_like(cache)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@rbln_cache_update.register_fake
|
|
28
|
+
def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
|
|
29
|
+
return torch.empty_like(cache)
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import Tensor
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
|
|
22
|
+
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
|
|
23
|
+
output_shape = list(input.shape[:-1])
|
|
24
|
+
output_shape += [weight.shape[0]]
|
|
25
|
+
return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@linear.register_fake
def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """Fake (tracing) implementation of ``rbln_custom_ops::linear``.

    Uses the same shape rule as the op itself: ``input.shape[:-1]`` followed
    by ``weight.shape[0]``. The result takes ``input``'s dtype, device and
    ``requires_grad`` flag.
    """
    result_shape = [*input.shape[:-1], weight.shape[0]]
    return torch.empty(
        size=result_shape,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
    )
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch import Tensor
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@torch.library.custom_op(
|
|
21
|
+
"rbln_custom_ops::paged_sliding_window_attn_prefill",
|
|
22
|
+
mutates_args=(["kcache", "vcache"]),
|
|
23
|
+
)
|
|
24
|
+
def paged_sliding_window_attn_prefill(
|
|
25
|
+
q: Tensor,
|
|
26
|
+
k: Tensor,
|
|
27
|
+
v: Tensor,
|
|
28
|
+
kcache: Tensor,
|
|
29
|
+
vcache: Tensor,
|
|
30
|
+
cache_seq_len: Tensor,
|
|
31
|
+
cache_offset: Tensor,
|
|
32
|
+
scale: Tensor,
|
|
33
|
+
block_table: Tensor,
|
|
34
|
+
block_size: int,
|
|
35
|
+
is_bidirectional: bool,
|
|
36
|
+
) -> Tensor:
|
|
37
|
+
"""Defines the computation pattern for prefill phase attention with KV cache updates.
|
|
38
|
+
|
|
39
|
+
IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
|
|
40
|
+
a single optimized NPU operation. It is NOT meant for CPU execution.
|
|
41
|
+
|
|
42
|
+
Key differences from decode pattern:
|
|
43
|
+
- Handles prefill phase with multiple input tokens
|
|
44
|
+
- Takes explicit batch index for continuous batching
|
|
45
|
+
|
|
46
|
+
Expected tensor shapes:
|
|
47
|
+
- q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
|
|
48
|
+
- k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
|
|
49
|
+
- v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
|
|
50
|
+
- kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
|
|
51
|
+
- vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
|
|
52
|
+
- cache_seq_len: [] - the sequence length of the cached states that were seen by the model
|
|
53
|
+
- cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
|
|
54
|
+
- scale: [] - Attention scale factor
|
|
55
|
+
- is_bidirectional: [] - Whether the attention is bidirectional
|
|
56
|
+
Returns:
|
|
57
|
+
Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
|
|
58
|
+
"""
|
|
59
|
+
return torch.empty_like(q)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@paged_sliding_window_attn_prefill.register_fake
def paged_sliding_window_attn_prefill_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    cache_seq_len: Tensor,
    cache_offset: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
    is_bidirectional: bool,
) -> Tensor:
    """Fake (tracing) implementation of ``rbln_custom_ops::paged_sliding_window_attn_prefill``.

    Only the output metadata is needed during tracing, so this returns
    ``torch.empty_like(q)``. All other arguments are accepted to match the
    op signature and are not read.
    """
    attn_output = torch.empty_like(q)
    return attn_output
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@torch.library.custom_op(
|
|
80
|
+
"rbln_custom_ops::paged_sliding_window_attn_decode",
|
|
81
|
+
mutates_args=(["kcache", "vcache"]),
|
|
82
|
+
)
|
|
83
|
+
def paged_sliding_window_attn_decode(
|
|
84
|
+
q: Tensor,
|
|
85
|
+
k: Tensor,
|
|
86
|
+
v: Tensor,
|
|
87
|
+
kcache: Tensor,
|
|
88
|
+
vcache: Tensor,
|
|
89
|
+
cache_seq_len: Tensor,
|
|
90
|
+
cache_offset: Tensor,
|
|
91
|
+
scale: Tensor,
|
|
92
|
+
block_table: Tensor,
|
|
93
|
+
block_size: int,
|
|
94
|
+
) -> Tensor:
|
|
95
|
+
return torch.empty_like(q)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@paged_sliding_window_attn_decode.register_fake
def paged_sliding_window_attn_decode_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    kcache: Tensor,
    vcache: Tensor,
    cache_seq_len: Tensor,
    cache_offset: Tensor,
    scale: Tensor,
    block_table: Tensor,
    block_size: int,
) -> Tensor:
    """Fake (tracing) implementation of ``rbln_custom_ops::paged_sliding_window_attn_decode``.

    Produces only the output metadata via ``torch.empty_like(q)``. The
    remaining arguments exist to match the op signature and are not read.
    """
    attn_output = torch.empty_like(q)
    return attn_output
|