optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +32 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/configuration_utils.py +20 -4
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +3 -2
- optimum/rbln/modeling_base.py +29 -4
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/transformers/__init__.py +24 -0
- optimum/rbln/transformers/configuration_generic.py +6 -4
- optimum/rbln/transformers/modeling_generic.py +13 -8
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +31 -16
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +25 -251
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
- optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
- optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
- optimum/rbln/utils/runtime_utils.py +3 -3
- optimum/rbln/utils/submodule.py +10 -4
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,16 +13,19 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
-from ...modeling_attention_utils import DEFAULT_FLASH_ATTN_PARTITION_LENGTH
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from .
+from ...utils.rbln_quantization import RBLNQuantizationConfig
+
+
+if TYPE_CHECKING:
+    from .configuration_decoderonly import RBLNDecoderOnlyModelConfig
 
 
 logger = logging.get_logger(__name__)
@@ -42,16 +45,9 @@ class DecoderOnlyWrapper(nn.Module):
     - Wrapper should not contain neural network graph operations (including memory view handling)
 
     Args:
-
-
+        model (PreTrainedModel): The Huggingface causal language model to wrap
+        rbln_config: The RBLN model configuration containing all necessary parameters
         use_rotary_emb (bool): Whether to use rotary position embeddings
-        attn_impl (str): The attention implementation to use.
-            - "eager": Uses the standard attention.
-            - "flash_attn": Uses flash attention. When set,
-              the key/value cache is partitioned into chunks of length
-              `kvcache_partition_len`.
-        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
-            This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
     _use_learned_pos_emb = False
@@ -59,24 +55,17 @@ class DecoderOnlyWrapper(nn.Module):
     def __init__(
         self,
         model: PreTrainedModel,
-
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         use_rotary_emb: bool,
-        attn_impl: str,
-        cache_impl: CacheImplType,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        kvcache_partition_len: Optional[int] = None,
-        kvcache_block_size: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        sliding_window_layers: Optional[List[int]] = None,
     ):
         super().__init__()
+        self.quantization = rbln_config.quantization
         self.config = model.config
         self.is_causal_lm = getattr(model, "lm_head", None) is not None
+        self.rbln_config = rbln_config
 
         if use_rotary_emb:
-            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
             if isinstance(rotary_embs, tuple):
                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
             else:
@@ -84,31 +73,13 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             self.rotary_emb = None
 
-
-        self.kvcache_block_size = kvcache_block_size
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
-        self.use_inputs_embeds = use_inputs_embeds
-        self.sliding_window_layers = sliding_window_layers
-        self.cache_impl = cache_impl
-        self.use_global_attention = cache_impl in ["static", "hybrid"]
-        self.use_local_attention = cache_impl in ["hybrid", "sliding_window"]
-        self.sliding_window = sliding_window
-
-        if self.attn_impl == "flash_attn":
-            self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-        elif self.attn_impl == "eager":
-            self.kvcache_partition_len = None
-        else:
-            raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
-
-        if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+        if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
             raise ValueError(
-                f"kvcache_partition_len({kvcache_partition_len}) should be lower"
-                f" or equal to max_seq_len({max_seq_len})!"
+                f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+                f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )
 
-        self.model = self.convert_to_rbln_class(model, max_seq_len)
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
 
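The two hunks above replace DecoderOnlyWrapper's long keyword-argument list with a single rbln_config object. Below is a minimal, self-contained sketch of that calling pattern; the config class, its defaults, and the wrapper shown here are illustrative stand-ins, not the package's actual RBLNDecoderOnlyModelConfig or DecoderOnlyWrapper API.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class SketchDecoderOnlyConfig:
    # Field names mirror the attributes the diff reads off rbln_config.
    max_seq_len: int = 4096
    attn_impl: str = "eager"
    kvcache_partition_len: Optional[int] = None
    kvcache_block_size: Optional[int] = None
    use_inputs_embeds: bool = False
    use_attention_mask: bool = True
    use_position_ids: bool = False
    sliding_window: Optional[int] = None
    sliding_window_layers: List[int] = field(default_factory=list)
    quantization: Optional[object] = None


class SketchWrapper:
    # Before 0.8.3 each of these values arrived as its own keyword argument;
    # after the refactor the wrapper keeps one shared config object and reads settings from it.
    def __init__(self, model, rbln_config: SketchDecoderOnlyConfig, use_rotary_emb: bool):
        self.rbln_config = rbln_config
        self.quantization = rbln_config.quantization
        if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
            raise ValueError("kvcache_partition_len should be lower or equal to max_seq_len")


wrapper = SketchWrapper(model=None, rbln_config=SketchDecoderOnlyConfig(max_seq_len=8192), use_rotary_emb=True)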
@@ -139,17 +110,9 @@ class DecoderOnlyWrapper(nn.Module):
     def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
-            is_sliding = layer_idx in self.sliding_window_layers
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
             new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer),
-                self.use_attention_mask if not is_sliding else True,
-                self.use_position_ids,
-                kvcache_block_size=self.sliding_window
-                if layer_idx in self.sliding_window_layers
-                else self.kvcache_block_size,
-                is_sliding=is_sliding,
-                attn_impl=self.attn_impl if not is_sliding else "eager",
-                kvcache_partition_len=self.kvcache_partition_len,
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
             new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
             new_layers.append(new_layer)
@@ -157,11 +120,8 @@ class DecoderOnlyWrapper(nn.Module):
         new_model = self.get_rbln_model_class()(
             self.get_model_layer(model),
             new_layers,
-
-            max_seq_len=max_seq_len,
-            kvcache_block_size=self.kvcache_block_size,
+            self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-            sliding_window_layers=self.sliding_window_layers,
         )
 
         if self.is_causal_lm:
@@ -181,19 +141,19 @@ class DecoderOnlyWrapper(nn.Module):
 
     def prepare_forward_args(self, *args):
         args = list(args)
-        input_ids = None if self.use_inputs_embeds else args.pop(0)
-        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.use_global_attention else None
-        local_block_tables = args.pop(0) if self.use_local_attention else None
+        global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+        local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
         query_position = (
             args.pop(0)
             # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
-            if ("prefill" in self.phase and (self.is_causal_lm or self.use_local_attention))
+            if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
             else None
         )
-        attention_mask = args.pop(0) if self.use_attention_mask else None
-        position_ids = args.pop(0) if self.use_position_ids else None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -345,6 +305,8 @@ class DecoderOnlyModel(nn.Module):
     Args:
         model: Original Huggingface model to adapt
         layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+        rbln_config: RBLN model configuration
+        use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
     Attributes:
         _original_mod: Reference to original Huggingface model
@@ -356,21 +318,19 @@ class DecoderOnlyModel(nn.Module):
         self,
         model,
         layers: List["DecoderOnlyLayer"],
-
-        max_seq_len=None,
-        kvcache_block_size=None,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
-        sliding_window_layers=None,
     ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
+        self.rbln_config = rbln_config
         self._phase = "prefill"
-        self.partition_len =
-        self.kvcache_block_size = kvcache_block_size
-        self.max_seq_len = max_seq_len
+        self.partition_len = rbln_config.kvcache_partition_len
+        self.kvcache_block_size = rbln_config.kvcache_block_size
+        self.max_seq_len = rbln_config.max_seq_len
         self.use_learned_pos_emb = use_learned_pos_emb
-        self.sliding_window_layers = sliding_window_layers
+        self.sliding_window_layers = rbln_config.sliding_window_layers
 
     @property
     def phase(self):
@@ -600,25 +560,19 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
-
-        use_position_ids: Whether to use position ids
-        kvcache_block_size: Block size for KV cache
+        rbln_config: RBLN model configuration containing attention parameters
         is_sliding: Whether this is sliding window attention
-        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
    def __init__(
        self,
        self_attn,
-
-        use_position_ids,
-        kvcache_block_size,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
        is_sliding=False,
-        attn_impl="eager",
-        kvcache_partition_len=None,
    ):
        super().__init__()
        self._original_mod = self_attn
+        self.rbln_config = rbln_config
        self.layer_idx = self_attn.layer_idx
        self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
            self._original_mod.config, "num_attention_heads"
@@ -626,6 +580,7 @@ class DecoderOnlyAttention(nn.Module):
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
         self.scale = torch.tensor(self.get_attn_scale())
+        self.quantization = rbln_config.quantization
 
         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -634,14 +589,14 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
+        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
-        self.attn_impl = attn_impl
-        self.kvcache_partition_len = kvcache_partition_len
+        self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+        self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+        self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
 
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
     def get_attention_name(self):
@@ -681,6 +636,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.kvcache_partition_len,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
@@ -689,6 +645,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_key_value_heads,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
@@ -719,6 +676,16 @@ class DecoderOnlyAttention(nn.Module):
     def get_attn_scale(self):
         return 1 / math.sqrt(self.head_dim)
 
+    def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+            k_scale = getattr(self.k_proj, "k_scale", None)
+            v_scale = getattr(self.v_proj, "v_scale", None)
+        else:
+            k_scale = None
+            v_scale = None
+
+        return k_scale, v_scale
+
     def forward(
         self,
         hidden_states: torch.Tensor,
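The helper added above only reports scales that a quantized checkpoint has already attached to the key/value projection modules. The stand-alone sketch below shows that lookup in isolation; the free function and the toy layer are illustrative, and only the k_proj/v_proj module names and the k_scale/v_scale attribute names come from the hunk.

from typing import Optional, Tuple

import torch
from torch import nn


def maybe_get_kvcache_scale(attn: nn.Module) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mirrors the method above: return per-projection scales when present, else (None, None).
    if hasattr(attn, "k_proj") and hasattr(attn, "v_proj"):
        return getattr(attn.k_proj, "k_scale", None), getattr(attn.v_proj, "v_scale", None)
    return None, None


# Toy attention layer whose projections carry KV-cache scales.
layer = nn.Module()
layer.k_proj = nn.Linear(16, 16)
layer.v_proj = nn.Linear(16, 16)
layer.k_proj.k_scale = torch.tensor(0.5)
layer.v_proj.v_scale = torch.tensor(0.25)
print(maybe_get_kvcache_scale(layer))  # (tensor(0.5000), tensor(0.2500))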
@@ -748,6 +715,8 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
+        k_scale, v_scale = self.maybe_get_kvcache_scale()
+
         attn_output = self.get_attention_op()(
             query_states,
             key_states,
@@ -759,6 +728,8 @@ class DecoderOnlyAttention(nn.Module):
             scale=self.scale,
             block_tables=block_tables,
             block_size=self.kvcache_block_size,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
 
         attn_outputs = self.o_proj(attn_output)
@@ -775,7 +746,13 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
 
 class AttentionOp(nn.Module):
     def __init__(
-        self,
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        use_attention_mask: bool,
+        use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -784,10 +761,10 @@ class AttentionOp(nn.Module):
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
+        self.quantization = quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
         if self.use_attention_mask and not self.use_position_ids:
             attn_op_name = "paged_attn_"
         else:
@@ -795,6 +772,9 @@ class AttentionOp(nn.Module):
 
         attn_op_name += phase
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name
 
     def forward(
@@ -809,6 +789,8 @@ class AttentionOp(nn.Module):
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.
 
@@ -821,6 +803,10 @@ class AttentionOp(nn.Module):
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
             scale: Scale applied to attn weights
+            block_tables: Block tables for paged attention
+            block_size: Block size for paged attention
+            k_scale: Scale applied to key
+            v_scale: Scale applied to value
 
         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -864,6 +850,12 @@ class AttentionOp(nn.Module):
         if not self.use_attention_mask or self.use_position_ids:
             op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
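The fp8 branch above (and its twin in FlashAttentionOp further down) gates the extra op arguments on quantization.kv_caches == "fp8", requires the cache tensor to already be torch.float8_e4m3fn, and forwards the per-layer scales to the custom op. Below is a hedged sketch of just that gating logic outside the RBLN runtime; the function name and the op_args dict are illustrative, and it assumes a PyTorch build with float8 support.

from typing import Optional

import torch


def add_fp8_kv_args(
    op_args: dict,
    past_key_state: torch.Tensor,
    k_scale: Optional[torch.Tensor],
    v_scale: Optional[torch.Tensor],
    kv_caches: Optional[str],
) -> dict:
    # Same checks as the hunk above: fp8 KV cache requires a float8_e4m3fn cache tensor
    # and passes the key/value scales through to the attention op.
    if kv_caches == "fp8":
        if past_key_state.dtype != torch.float8_e4m3fn:
            raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
        op_args["k_scale"] = k_scale
        op_args["v_scale"] = v_scale
    return op_args


cache = torch.zeros(1, 8, 128, 64, dtype=torch.float8_e4m3fn)
args = add_fp8_kv_args({}, cache, torch.tensor(1.0), torch.tensor(1.0), kv_caches="fp8")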
@@ -886,6 +878,7 @@ class FlashAttentionOp(AttentionOp):
         kvcache_partition_len: int,
         use_attention_mask: bool,
         use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__(
             num_heads=num_heads,
@@ -893,6 +886,7 @@ class FlashAttentionOp(AttentionOp):
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
             use_position_ids=use_position_ids,
+            quantization=quantization,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
@@ -905,6 +899,9 @@ class FlashAttentionOp(AttentionOp):
 
         attn_op_name += phase
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name
 
     def forward(
@@ -919,6 +916,8 @@ class FlashAttentionOp(AttentionOp):
         scale,
         block_tables,
         block_size,
+        k_scale=None,
+        v_scale=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -959,6 +958,12 @@ class FlashAttentionOp(AttentionOp):
         if not self.use_attention_mask or self.use_position_ids:
             op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -986,14 +991,19 @@ class SlidingWindowAttentionOp(AttentionOp):
         query_state: torch.Tensor,
         key_state: torch.Tensor,
         value_state: torch.Tensor,
-        attn_mask: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: Tuple[torch.Tensor],
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert self.quantization is None, "Sliding window attention does not support quantization"
+        assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
+
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)