optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +53 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Any, Dict, List, Optional, Union
|
15
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
16
16
|
|
17
17
|
import rebel
|
18
18
|
|
@@ -23,8 +23,18 @@ from ...utils.rbln_quantization import RBLNQuantizationConfig
|
|
23
23
|
|
24
24
|
logger = get_logger()
|
25
25
|
|
26
|
+
# Allowed values for the `cache_impl` config option:
#   "static"         - one fixed-size global KV cache shared by all layers,
#   "sliding_window" - each layer keeps a local cache of recent tokens,
#   "hybrid"         - some layers use the global cache, others a sliding window.
CacheImplType = Literal["static", "sliding_window", "hybrid"]
|
27
|
+
|
26
28
|
|
27
29
|
class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
|
30
|
+
"""
|
31
|
+
Configuration class for RBLN decoder-only models for Causal Language Modeling.
|
32
|
+
|
33
|
+
This class extends RBLNModelConfig with parameters specific to decoder-only transformer
|
34
|
+
architectures optimized for RBLN devices. It controls aspects like attention implementation,
|
35
|
+
KV cache management, and batching for inference.
|
36
|
+
"""
|
37
|
+
|
28
38
|
def __init__(
|
29
39
|
self,
|
30
40
|
batch_size: Optional[int] = None,
|
@@ -39,36 +49,119 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
|
|
39
49
|
prefill_chunk_size: Optional[int] = None,
|
40
50
|
kvcache_num_blocks: Optional[int] = None,
|
41
51
|
decoder_batch_sizes: Optional[List[int]] = None,
|
52
|
+
cache_impl: Optional[CacheImplType] = None,
|
53
|
+
sliding_window: Optional[int] = None,
|
54
|
+
sliding_window_layers: Optional[List[int]] = None,
|
42
55
|
**kwargs,
|
43
56
|
):
|
44
57
|
"""
|
45
58
|
Args:
|
46
59
|
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
47
60
|
max_seq_len (Optional[int]): The maximum sequence length supported by the model.
|
48
|
-
|
49
|
-
|
50
|
-
|
61
|
+
If not provided, it attempts to infer from the model's configuration
|
62
|
+
(`max_position_embeddings` or `n_positions`). Must be specified if not available
|
63
|
+
in the model config.
|
64
|
+
use_inputs_embeds (Optional[bool]): Whether to use input embeddings (`inputs_embeds`)
|
65
|
+
directly instead of `input_ids`. Defaults to False. Requires the model to be
|
66
|
+
compiled with this option enabled.
|
67
|
+
use_attention_mask (Optional[bool]): Whether the model requires attention masks during
|
68
|
+
inference. This is typically determined based on the target device and model
|
69
|
+
architecture. Defaults are often set automatically based on the model and RBLN NPU.
|
51
70
|
use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
|
52
|
-
attn_impl (Optional[str]):
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
71
|
+
attn_impl (Optional[str]): Specifies the attention implementation to use.
|
72
|
+
See the "Attention Implementation (`attn_impl`)" section below for details.
|
73
|
+
kvcache_partition_len (Optional[int]): Defines the partition length for the KV cache
|
74
|
+
when using "flash_attn". See the "KV Cache Partition Length (`kvcache_partition_len`)"
|
75
|
+
section below for details.
|
76
|
+
kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
|
77
|
+
in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
|
78
|
+
section below for details.
|
79
|
+
quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
|
80
|
+
quantization. Specifies format, etc.
|
81
|
+
prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
|
82
|
+
processing input sequences. Defaults to 128. Must be a positive integer
|
83
|
+
divisible by 64. Affects prefill performance and memory usage.
|
84
|
+
kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
|
85
|
+
PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
|
86
|
+
section below for details.
|
59
87
|
decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
|
60
88
|
This allows the model to handle varying batch sizes efficiently during generation. If not specified,
|
61
89
|
defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
|
62
90
|
1) All values must be less than or equal to the main batch size.
|
63
91
|
2) The list will be sorted in descending order (larger batch sizes first).
|
64
92
|
3) If using multiple decoders, at least one batch size should match the main batch size.
|
65
|
-
|
93
|
+
cache_impl (Optional[CacheImplType]): Specifies the KV cache implementation strategy. Defaults to "static".
|
94
|
+
- "static": Uses a fixed-size global KV cache for all layers, suitable for standard attention patterns.
|
95
|
+
- "sliding_window": Implements a sliding window KV cache, where each layer maintains a local cache of recent tokens.
|
96
|
+
- "hybrid": Combines both static and sliding window approaches, allowing different layers to use different cache strategies.
|
97
|
+
The choice affects memory usage and attention patterns. When using "sliding_window" or "hybrid",
|
98
|
+
you must specify the `sliding_window` size and optionally `sliding_window_layers` for hybrid mode.
|
99
|
+
sliding_window (Optional[int]): The size of the sliding window. Defaults to None.
|
100
|
+
sliding_window_layers (Optional[List[int]]): The layers to use for the sliding window used in the hybrid model. Defaults to None.
|
66
101
|
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
67
102
|
|
68
103
|
Raises:
|
69
|
-
ValueError: If batch_size is not a positive integer
|
70
|
-
|
104
|
+
ValueError: If `batch_size` is not a positive integer.
|
105
|
+
ValueError: If `prefill_chunk_size` is not a positive integer divisible by 64.
|
106
|
+
ValueError: If `max_seq_len` cannot be determined and is required.
|
107
|
+
ValueError: If attention parameter constraints are violated (e.g., `max_seq_len` vs
|
108
|
+
`kvcache_partition_len` for flash attention).
|
109
|
+
|
110
|
+
|
111
|
+
Attention Implementation:
|
112
|
+
`attn_impl` determines the underlying attention mechanism used by the model.
|
113
|
+
|
114
|
+
- **`"eager"`** (Default if `kvcache_partition_len` is not set): Uses the standard PyTorch
|
115
|
+
attention implementation. Suitable for sequences up to a certain limit (e.g., 32,768 tokens).
|
116
|
+
- **`"flash_attn"`**: Utilizes an optimized Flash Attention implementation, beneficial for
|
117
|
+
longer sequences and potentially faster execution. Requires `max_seq_len` to be at least
|
118
|
+
8,192. If `kvcache_partition_len` is specified, `attn_impl` automatically defaults
|
119
|
+
to `"flash_attn"`. When using `"flash_attn"`, `kvcache_block_size` must equal
|
120
|
+
`kvcache_partition_len`.
|
121
|
+
|
122
|
+
The choice impacts performance and memory usage, especially for long sequences.
|
123
|
+
Constraints related to `max_seq_len` and `kvcache_partition_len` apply when using
|
124
|
+
`"flash_attn"`.
|
125
|
+
|
126
|
+
|
127
|
+
KV Cache Partition Length:
|
128
|
+
`kvcache_partition_len` is relevant **only** when `attn_impl` is `"flash_attn"`.
|
129
|
+
|
130
|
+
- It defines the length (number of tokens) of each partition within the Key-Value (KV) cache.
|
131
|
+
- Must be between 4,096 and 32,768 (inclusive).
|
132
|
+
- When using `"flash_attn"`, `max_seq_len` must be a multiple of `kvcache_partition_len`
|
133
|
+
and at least twice its value (`max_seq_len >= 2 * kvcache_partition_len`).
|
134
|
+
- If `attn_impl` is `"flash_attn"` and `kvcache_partition_len` is `None`, it defaults to
|
135
|
+
16,384.
|
136
|
+
|
137
|
+
|
138
|
+
KV Cache Number of Blocks:
|
139
|
+
`kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
|
140
|
+
Each block holds `kvcache_block_size` tokens of Key and Value states.
|
141
|
+
|
142
|
+
- **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
|
143
|
+
the maximum number of blocks that can fit into the available RBLN device memory. This
|
144
|
+
calculation considers the model size (kernel memory), required buffer memory, the number
|
145
|
+
of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
|
146
|
+
This aims to maximize cache capacity for potentially better performance with long sequences
|
147
|
+
or larger batches without manual tuning.
|
148
|
+
- **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
|
149
|
+
but requires careful consideration of memory limits. Setting it too high may lead to
|
150
|
+
compilation errors if it exceeds available memory. The system will issue warnings if your
|
151
|
+
setting exceeds the estimated maximum.
|
152
|
+
- **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
|
153
|
+
which is beneficial for tasks involving many long sequences or large batch sizes, enabling
|
154
|
+
higher throughput. However, allocating more blocks consumes more memory.
|
155
|
+
- **Minimum Requirement**: The system requires a minimum number of blocks to function,
|
156
|
+
calculated based on `max_seq_len`, `kvcache_block_size`, and `batch_size`. The number of
|
157
|
+
allocated blocks must be sufficient to hold at least one full sequence length per item
|
158
|
+
in the batch concurrently. The system will log warnings or raise errors if constraints
|
159
|
+
are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
|
160
|
+
|
161
|
+
The optimal value depends on the specific model, task, hardware, and desired trade-off
|
162
|
+
between performance and memory usage. The automatic estimation provides a robust starting point.
|
71
163
|
"""
|
164
|
+
|
72
165
|
super().__init__(**kwargs)
|
73
166
|
self.batch_size = batch_size or 1
|
74
167
|
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
@@ -121,6 +214,10 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
|
|
121
214
|
# Larger batch size should be at the beginning of the list.
|
122
215
|
self.decoder_batch_sizes.sort(reverse=True)
|
123
216
|
|
217
|
+
self.cache_impl = cache_impl or "static"
|
218
|
+
self.sliding_window = sliding_window
|
219
|
+
self.sliding_window_layers = sliding_window_layers or []
|
220
|
+
|
124
221
|
@property
def use_multiple_decoder(self):
    """Whether more than one decoder batch size has been configured.

    Returns:
        True only when ``decoder_batch_sizes`` is a list with at least two
        entries; False otherwise (including when it is not a list at all).
    """
    sizes = self.decoder_batch_sizes
    if not isinstance(sizes, list):
        return False
    return len(sizes) > 1
|
@@ -21,6 +21,7 @@ from transformers import PretrainedConfig, PreTrainedModel
|
|
21
21
|
|
22
22
|
from ....utils import logging
|
23
23
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
24
|
+
from .configuration_decoderonly import CacheImplType
|
24
25
|
|
25
26
|
|
26
27
|
logger = logging.get_logger(__name__)
|
@@ -30,6 +31,7 @@ DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
|
|
30
31
|
MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
|
31
32
|
MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
|
32
33
|
MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
|
34
|
+
MAX_SLIDING_WINDOW_SIZE = 32_768
|
33
35
|
|
34
36
|
|
35
37
|
def set_default_values(
|
@@ -114,6 +116,13 @@ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcach
|
|
114
116
|
)
|
115
117
|
|
116
118
|
|
119
|
+
def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
    """Check that the sliding window plus one prefill chunk stays within the limit.

    The check enforces ``sliding_window + prefill_chunk_size <= MAX_SLIDING_WINDOW_SIZE``.
    Equality is accepted.

    Args:
        sliding_window: Size of the sliding window.
        prefill_chunk_size: Chunk size used during the prefill phase.

    Raises:
        ValueError: If ``sliding_window`` exceeds
            ``MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size``.
    """
    max_allowed = MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size
    if sliding_window > max_allowed:
        # Report the limit from the shared constant (not a duplicated literal), and
        # say "less than or equal to" because an exactly-equal window passes the guard.
        raise ValueError(
            f"Sliding window size ({sliding_window}) must be less than or equal to "
            f"{MAX_SLIDING_WINDOW_SIZE} - prefill_chunk_size ({max_allowed})"
        )
|
124
|
+
|
125
|
+
|
117
126
|
class DecoderOnlyWrapper(nn.Module):
|
118
127
|
"""A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
|
119
128
|
|
@@ -146,12 +155,15 @@ class DecoderOnlyWrapper(nn.Module):
|
|
146
155
|
max_seq_len: int,
|
147
156
|
use_rotary_emb: bool,
|
148
157
|
attn_impl: str,
|
158
|
+
cache_impl: CacheImplType,
|
149
159
|
use_inputs_embeds: bool,
|
150
160
|
use_attention_mask: bool,
|
151
161
|
use_position_ids: bool,
|
152
162
|
use_learned_pos_emb: Optional[bool] = None,
|
153
163
|
kvcache_partition_len: Optional[int] = None,
|
154
164
|
kvcache_block_size: Optional[int] = None,
|
165
|
+
sliding_window: Optional[int] = None,
|
166
|
+
sliding_window_layers: Optional[List[int]] = None,
|
155
167
|
):
|
156
168
|
super().__init__()
|
157
169
|
self.config = causal_lm.config
|
@@ -171,6 +183,9 @@ class DecoderOnlyWrapper(nn.Module):
|
|
171
183
|
self.use_position_ids = use_position_ids
|
172
184
|
self.use_inputs_embeds = use_inputs_embeds
|
173
185
|
self.use_learned_pos_emb = use_learned_pos_emb
|
186
|
+
self.sliding_window_layers = sliding_window_layers
|
187
|
+
self.cache_impl = cache_impl
|
188
|
+
self.sliding_window = sliding_window
|
174
189
|
|
175
190
|
if self.attn_impl == "flash_attn":
|
176
191
|
self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
|
@@ -186,7 +201,6 @@ class DecoderOnlyWrapper(nn.Module):
|
|
186
201
|
)
|
187
202
|
|
188
203
|
self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
|
189
|
-
|
190
204
|
self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
|
191
205
|
self._phase = "prefill"
|
192
206
|
|
@@ -195,25 +209,35 @@ class DecoderOnlyWrapper(nn.Module):
|
|
195
209
|
|
196
210
|
def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
|
197
211
|
new_layers = []
|
198
|
-
|
199
|
-
|
200
|
-
|
212
|
+
for layer_idx, layer in enumerate(causal_lm.model.layers):
|
213
|
+
if layer_idx in self.sliding_window_layers:
|
214
|
+
# Flash attention is not yet supported for sliding window attention.
|
201
215
|
new_self_attn = DecoderOnlyAttention(
|
202
216
|
layer.self_attn,
|
203
217
|
self.use_attention_mask,
|
204
218
|
self.use_position_ids,
|
205
|
-
kvcache_block_size=self.
|
206
|
-
|
207
|
-
elif self.attn_impl == "flash_attn":
|
208
|
-
new_self_attn = DecoderOnlyFlashAttention(
|
209
|
-
layer.self_attn,
|
210
|
-
kvcache_partition_len=self.kvcache_partition_len,
|
211
|
-
kvcache_block_size=self.kvcache_block_size,
|
212
|
-
use_attention_mask=self.use_attention_mask,
|
213
|
-
use_position_ids=self.use_position_ids,
|
219
|
+
kvcache_block_size=self.sliding_window,
|
220
|
+
is_sliding=True,
|
214
221
|
)
|
215
222
|
else:
|
216
|
-
|
223
|
+
if self.attn_impl == "eager":
|
224
|
+
new_self_attn = DecoderOnlyAttention(
|
225
|
+
layer.self_attn,
|
226
|
+
self.use_attention_mask,
|
227
|
+
self.use_position_ids,
|
228
|
+
kvcache_block_size=self.kvcache_block_size,
|
229
|
+
is_sliding=False,
|
230
|
+
)
|
231
|
+
elif self.attn_impl == "flash_attn":
|
232
|
+
new_self_attn = DecoderOnlyFlashAttention(
|
233
|
+
layer.self_attn,
|
234
|
+
kvcache_partition_len=self.kvcache_partition_len,
|
235
|
+
kvcache_block_size=self.kvcache_block_size,
|
236
|
+
use_attention_mask=self.use_attention_mask,
|
237
|
+
use_position_ids=self.use_position_ids,
|
238
|
+
)
|
239
|
+
else:
|
240
|
+
raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
|
217
241
|
|
218
242
|
new_layer = DecoderOnlyLayer(layer, new_self_attn)
|
219
243
|
new_layers.append(new_layer)
|
@@ -225,6 +249,7 @@ class DecoderOnlyWrapper(nn.Module):
|
|
225
249
|
max_seq_len=max_seq_len,
|
226
250
|
kvcache_block_size=self.kvcache_block_size,
|
227
251
|
use_learned_pos_emb=self.use_learned_pos_emb,
|
252
|
+
sliding_window_layers=self.sliding_window_layers,
|
228
253
|
)
|
229
254
|
new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
|
230
255
|
return new_causal_lm
|
@@ -243,8 +268,9 @@ class DecoderOnlyWrapper(nn.Module):
|
|
243
268
|
input_ids = None if self.use_inputs_embeds else args.pop(0)
|
244
269
|
inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
|
245
270
|
cache_position = args.pop(0)
|
246
|
-
|
247
|
-
|
271
|
+
global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
|
272
|
+
local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
|
273
|
+
query_position = args.pop(0) if "prefill" in self.phase else None
|
248
274
|
attention_mask = args.pop(0) if self.use_attention_mask else None
|
249
275
|
position_ids = args.pop(0) if self.use_position_ids else None
|
250
276
|
past_key_values = args
|
@@ -264,16 +290,22 @@ class DecoderOnlyWrapper(nn.Module):
|
|
264
290
|
_past_key_values.append(past_key_value)
|
265
291
|
past_key_values = _past_key_values
|
266
292
|
|
293
|
+
if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
|
294
|
+
rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
|
295
|
+
else:
|
296
|
+
rotary_emb = self.rotary_emb
|
297
|
+
|
267
298
|
return (
|
268
299
|
input_ids,
|
269
300
|
inputs_embeds,
|
270
301
|
cache_position,
|
271
|
-
|
302
|
+
global_block_tables,
|
303
|
+
local_block_tables,
|
272
304
|
query_position,
|
273
305
|
attention_mask,
|
274
306
|
position_ids,
|
275
307
|
past_key_values,
|
276
|
-
|
308
|
+
rotary_emb,
|
277
309
|
)
|
278
310
|
|
279
311
|
def forward(self, *args):
|
@@ -281,7 +313,8 @@ class DecoderOnlyWrapper(nn.Module):
|
|
281
313
|
input_ids,
|
282
314
|
inputs_embeds,
|
283
315
|
cache_position,
|
284
|
-
|
316
|
+
global_block_tables,
|
317
|
+
local_block_tables,
|
285
318
|
query_position,
|
286
319
|
attention_mask,
|
287
320
|
position_ids,
|
@@ -298,7 +331,8 @@ class DecoderOnlyWrapper(nn.Module):
|
|
298
331
|
query_position=query_position,
|
299
332
|
past_key_values=past_key_values,
|
300
333
|
rotary_emb=rotary_emb,
|
301
|
-
|
334
|
+
global_block_tables=global_block_tables,
|
335
|
+
local_block_tables=local_block_tables,
|
302
336
|
)
|
303
337
|
|
304
338
|
return logit
|
@@ -353,7 +387,8 @@ class DecoderOnlyForCausalLM(nn.Module):
|
|
353
387
|
query_position: torch.Tensor = None,
|
354
388
|
past_key_values: Tuple[Tuple[torch.Tensor]] = None,
|
355
389
|
rotary_emb: nn.Module = None,
|
356
|
-
|
390
|
+
global_block_tables: Optional[torch.Tensor] = None,
|
391
|
+
local_block_tables: Optional[torch.Tensor] = None,
|
357
392
|
):
|
358
393
|
# outputs
|
359
394
|
hidden_states = self.model(
|
@@ -362,12 +397,14 @@ class DecoderOnlyForCausalLM(nn.Module):
|
|
362
397
|
attention_mask=attention_mask,
|
363
398
|
cache_position=cache_position,
|
364
399
|
position_ids=position_ids,
|
400
|
+
query_position=query_position,
|
365
401
|
past_key_values=past_key_values,
|
366
402
|
rotary_emb=rotary_emb,
|
367
|
-
|
403
|
+
global_block_tables=global_block_tables,
|
404
|
+
local_block_tables=local_block_tables,
|
368
405
|
)
|
369
406
|
|
370
|
-
if self.phase
|
407
|
+
if "prefill" in self.phase:
|
371
408
|
hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
|
372
409
|
|
373
410
|
logits = self.lm_head(hidden_states)
|
@@ -402,6 +439,7 @@ class DecoderOnlyModel(nn.Module):
|
|
402
439
|
max_seq_len=None,
|
403
440
|
kvcache_block_size=None,
|
404
441
|
use_learned_pos_emb=None,
|
442
|
+
sliding_window_layers=None,
|
405
443
|
):
|
406
444
|
super().__init__()
|
407
445
|
self._original_mod = model
|
@@ -411,6 +449,7 @@ class DecoderOnlyModel(nn.Module):
|
|
411
449
|
self.kvcache_block_size = kvcache_block_size
|
412
450
|
self.max_seq_len = max_seq_len
|
413
451
|
self.use_learned_pos_emb = use_learned_pos_emb
|
452
|
+
self.sliding_window_layers = sliding_window_layers
|
414
453
|
|
415
454
|
@property
|
416
455
|
def phase(self):
|
@@ -441,6 +480,16 @@ class DecoderOnlyModel(nn.Module):
|
|
441
480
|
cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
|
442
481
|
return cache_pos_for_partitions
|
443
482
|
|
483
|
+
def get_local_cache_positions(self, position_ids, query_position):
    """Compute read/write positions for the sliding-window (local) KV cache.

    Args:
        position_ids: Position ids of the current step. Only the first column
            (``[:, :1]``) is used, after clamping to the configured
            ``sliding_window`` size.
        query_position: Query position within the prefill chunk, or ``None``.
            The wrapper passes ``None`` outside the prefill phase, in which case
            exactly one new token is counted.

    Returns:
        ``(cache_seq_len, cache_offset)``:
        - ``cache_seq_len`` is the number of past tokens already seen, capped at
          the window size.
        - ``cache_offset`` is the cache offset for the next step, i.e.
          ``cache_seq_len`` plus the number of valid input tokens.
    """
    window_size = self._original_mod.config.sliding_window

    if query_position is None:
        num_new_tokens = 1
    else:
        num_new_tokens = query_position + 1

    # Past seen tokens, never more than the window can hold.
    seen_tokens = torch.clamp(position_ids, max=window_size)[:, :1]
    # Offset where the following step continues writing.
    next_offset = seen_tokens + num_new_tokens
    return seen_tokens, next_offset
|
492
|
+
|
444
493
|
def get_last_layernorm(self) -> nn.LayerNorm:
    """Return the final normalization module (``norm``) of the wrapped base model."""
    base_model = self._original_mod
    final_norm = base_model.norm
    return final_norm
|
446
495
|
|
@@ -459,9 +508,11 @@ class DecoderOnlyModel(nn.Module):
|
|
459
508
|
attention_mask: torch.Tensor = None,
|
460
509
|
cache_position: torch.Tensor = None,
|
461
510
|
position_ids: torch.Tensor = None,
|
511
|
+
query_position: torch.Tensor = None,
|
462
512
|
past_key_values: Tuple[Tuple[torch.Tensor]] = None,
|
463
513
|
rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
|
464
|
-
|
514
|
+
global_block_tables: Optional[torch.Tensor] = None,
|
515
|
+
local_block_tables: Optional[torch.Tensor] = None,
|
465
516
|
):
|
466
517
|
# retrieve input_ids and inputs_embeds
|
467
518
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
@@ -511,7 +562,7 @@ class DecoderOnlyModel(nn.Module):
|
|
511
562
|
hidden_states = hidden_states + position_embeds
|
512
563
|
cos, sin = None, None
|
513
564
|
|
514
|
-
#
|
565
|
+
# Get sequence positions for flash attention
|
515
566
|
if self.attn_impl == "flash_attn":
|
516
567
|
seq_positions = cache_position[:, 0]
|
517
568
|
seq_positions = self.convert_sequence_positions_for_flash_attn(
|
@@ -520,15 +571,20 @@ class DecoderOnlyModel(nn.Module):
|
|
520
571
|
else:
|
521
572
|
seq_positions = cache_position[:, :1]
|
522
573
|
|
523
|
-
for
|
574
|
+
# Get local cache positions for sliding window layers
|
575
|
+
if len(self.sliding_window_layers) > 0:
|
576
|
+
sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
|
577
|
+
|
578
|
+
for layer_idx, layer in enumerate(self.layers):
|
579
|
+
is_sliding = True if layer_idx in self.sliding_window_layers else False
|
524
580
|
hidden_states = layer(
|
525
581
|
hidden_states=hidden_states,
|
526
582
|
attention_mask=attention_mask,
|
527
|
-
seq_positions=seq_positions,
|
583
|
+
seq_positions=sliding_cache_pos if is_sliding else seq_positions,
|
528
584
|
past_key_values=past_key_values,
|
529
585
|
cos=cos,
|
530
586
|
sin=sin,
|
531
|
-
block_tables=
|
587
|
+
block_tables=local_block_tables if is_sliding else global_block_tables,
|
532
588
|
)
|
533
589
|
|
534
590
|
hidden_states = self.get_last_layernorm()(hidden_states)
|
@@ -625,7 +681,7 @@ class DecoderOnlyAttention(nn.Module):
|
|
625
681
|
self_attn: Original attention module from the base model
|
626
682
|
"""
|
627
683
|
|
628
|
-
def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
|
684
|
+
def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size, is_sliding=False):
|
629
685
|
super().__init__()
|
630
686
|
self._original_mod = self_attn
|
631
687
|
self.layer_idx = self_attn.layer_idx
|
@@ -645,6 +701,7 @@ class DecoderOnlyAttention(nn.Module):
|
|
645
701
|
|
646
702
|
self.use_attention_mask = use_attention_mask
|
647
703
|
self.use_position_ids = use_position_ids
|
704
|
+
self.is_sliding = is_sliding
|
648
705
|
self.attention = self.get_attention()
|
649
706
|
self.kvcache_block_size = kvcache_block_size
|
650
707
|
self.__post_init__()
|
@@ -659,9 +716,14 @@ class DecoderOnlyAttention(nn.Module):
|
|
659
716
|
self.attention.phase = phase
|
660
717
|
|
661
718
|
def get_attention(self):
    """Construct the attention operator used by this layer.

    Layers built with ``is_sliding=True`` get a ``SlidingWindowAttentionOp``.
    All other layers get a regular ``AttentionOp``. Both operators receive the
    same head configuration and mask/position flags.
    """
    op_class = SlidingWindowAttentionOp if self.is_sliding else AttentionOp
    return op_class(
        self.num_heads,
        self.head_dim,
        self.num_key_value_heads,
        self.use_attention_mask,
        self.use_position_ids,
    )
|
665
727
|
|
666
728
|
def __post_init__(self):
|
667
729
|
self.q_proj = self._original_mod.q_proj
|
@@ -708,12 +770,14 @@ class DecoderOnlyAttention(nn.Module):
|
|
708
770
|
value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
|
709
771
|
1, 2
|
710
772
|
)
|
711
|
-
|
773
|
+
if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
|
774
|
+
query_states = self.q_norm(query_states)
|
775
|
+
key_states = self.k_norm(key_states)
|
712
776
|
|
713
777
|
if cos is not None and sin is not None:
|
714
778
|
query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
|
715
779
|
|
716
|
-
if batch_size > 1 and self.phase
|
780
|
+
if batch_size > 1 and "prefill" in self.phase:
|
717
781
|
raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
|
718
782
|
|
719
783
|
attn_output = self.attention(
|
@@ -987,7 +1051,10 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
|
|
987
1051
|
value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
|
988
1052
|
1, 2
|
989
1053
|
)
|
990
|
-
|
1054
|
+
|
1055
|
+
if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
|
1056
|
+
query_states = self.q_norm(query_states)
|
1057
|
+
key_states = self.k_norm(key_states)
|
991
1058
|
|
992
1059
|
if cos is not None and sin is not None:
|
993
1060
|
query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
|