optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +53 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  54. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  55. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  56. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  57. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  58. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  59. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
  60. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  63. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  64. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  65. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  66. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  67. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  68. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  69. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  70. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  71. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
  72. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  73. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  74. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  75. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  76. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  77. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  78. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  79. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  80. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  81. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  82. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  83. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  84. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  85. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  86. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  87. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  91. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  92. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  93. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  94. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
  95. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
  96. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  97. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  98. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  99. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  100. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  101. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  102. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  103. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  104. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  105. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  106. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  108. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  109. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  110. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  111. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  112. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  113. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  114. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  115. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  116. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  117. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  118. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  119. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  123. optimum/rbln/utils/model_utils.py +20 -0
  124. optimum/rbln/utils/submodule.py +6 -8
  125. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
  126. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
  127. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  128. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  129. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
  130. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Literal, Optional, Union

  import rebel

@@ -23,8 +23,18 @@ from ...utils.rbln_quantization import RBLNQuantizationConfig

  logger = get_logger()

+ CacheImplType = Literal["static", "sliding_window", "hybrid"]
+

  class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+ """
+ Configuration class for RBLN decoder-only models for Causal Language Modeling.
+
+ This class extends RBLNModelConfig with parameters specific to decoder-only transformer
+ architectures optimized for RBLN devices. It controls aspects like attention implementation,
+ KV cache management, and batching for inference.
+ """
+
  def __init__(
  self,
  batch_size: Optional[int] = None,
@@ -39,36 +49,119 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
  prefill_chunk_size: Optional[int] = None,
  kvcache_num_blocks: Optional[int] = None,
  decoder_batch_sizes: Optional[List[int]] = None,
+ cache_impl: Optional[CacheImplType] = None,
+ sliding_window: Optional[int] = None,
+ sliding_window_layers: Optional[List[int]] = None,
  **kwargs,
  ):
  """
  Args:
  batch_size (Optional[int]): The batch size for inference. Defaults to 1.
  max_seq_len (Optional[int]): The maximum sequence length supported by the model.
- use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
- use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
- for RBLN-CA02 devices.
+ If not provided, it attempts to infer from the model's configuration
+ (`max_position_embeddings` or `n_positions`). Must be specified if not available
+ in the model config.
+ use_inputs_embeds (Optional[bool]): Whether to use input embeddings (`inputs_embeds`)
+ directly instead of `input_ids`. Defaults to False. Requires the model to be
+ compiled with this option enabled.
+ use_attention_mask (Optional[bool]): Whether the model requires attention masks during
+ inference. This is typically determined based on the target device and model
+ architecture. Defaults are often set automatically based on the model and RBLN NPU.
  use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
- attn_impl (Optional[str]): The attention implementation to use.
- kvcache_partition_len (Optional[int]): The length of each KV cache partition.
- kvcache_block_size (Optional[int]): The block size for KV cache.
- quantization (Optional[Dict[str, Any]]): Configuration for model quantization.
- prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
- and must be a positive integer divisible by 64.
- kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
+ attn_impl (Optional[str]): Specifies the attention implementation to use.
+ See the "Attention Implementation (`attn_impl`)" section below for details.
+ kvcache_partition_len (Optional[int]): Defines the partition length for the KV cache
+ when using "flash_attn". See the "KV Cache Partition Length (`kvcache_partition_len`)"
+ section below for details.
+ kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+ in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
+ section below for details.
+ quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+ quantization. Specifies format, etc.
+ prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
+ processing input sequences. Defaults to 128. Must be a positive integer
+ divisible by 64. Affects prefill performance and memory usage.
+ kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+ PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
+ section below for details.
  decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
  This allows the model to handle varying batch sizes efficiently during generation. If not specified,
  defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
  1) All values must be less than or equal to the main batch size.
  2) The list will be sorted in descending order (larger batch sizes first).
  3) If using multiple decoders, at least one batch size should match the main batch size.
-
+ cache_impl (Optional[CacheImplType]): Specifies the KV cache implementation strategy. Defaults to "static".
+ - "static": Uses a fixed-size global KV cache for all layers, suitable for standard attention patterns.
+ - "sliding_window": Implements a sliding window KV cache, where each layer maintains a local cache of recent tokens.
+ - "hybrid": Combines both static and sliding window approaches, allowing different layers to use different cache strategies.
+ The choice affects memory usage and attention patterns. When using "sliding_window" or "hybrid",
+ you must specify the `sliding_window` size and optionally `sliding_window_layers` for hybrid mode.
+ sliding_window (Optional[int]): The size of the sliding window. Defaults to None.
+ sliding_window_layers (Optional[List[int]]): The layers to use for the sliding window used in the hybrid model. Defaults to None.
  **kwargs: Additional arguments passed to the parent RBLNModelConfig.

  Raises:
- ValueError: If batch_size is not a positive integer or if prefill_chunk_size is not
- a positive integer divisible by 64.
+ ValueError: If `batch_size` is not a positive integer.
+ ValueError: If `prefill_chunk_size` is not a positive integer divisible by 64.
+ ValueError: If `max_seq_len` cannot be determined and is required.
+ ValueError: If attention parameter constraints are violated (e.g., `max_seq_len` vs
+ `kvcache_partition_len` for flash attention).
+
+
+ Attention Implementation:
+ `attn_impl` determines the underlying attention mechanism used by the model.
+
+ - **`"eager"`** (Default if `kvcache_partition_len` is not set): Uses the standard PyTorch
+ attention implementation. Suitable for sequences up to a certain limit (e.g., 32,768 tokens).
+ - **`"flash_attn"`**: Utilizes an optimized Flash Attention implementation, beneficial for
+ longer sequences and potentially faster execution. Requires `max_seq_len` to be at least
+ 8,192. If `kvcache_partition_len` is specified, `attn_impl` automatically defaults
+ to `"flash_attn"`. When using `"flash_attn"`, `kvcache_block_size` must equal
+ `kvcache_partition_len`.
+
+ The choice impacts performance and memory usage, especially for long sequences.
+ Constraints related to `max_seq_len` and `kvcache_partition_len` apply when using
+ `"flash_attn"`.
+
+
+ KV Cache Partition Length:
+ `kvcache_partition_len` is relevant **only** when `attn_impl` is `"flash_attn"`.
+
+ - It defines the length (number of tokens) of each partition within the Key-Value (KV) cache.
+ - Must be between 4,096 and 32,768 (inclusive).
+ - When using `"flash_attn"`, `max_seq_len` must be a multiple of `kvcache_partition_len`
+ and at least twice its value (`max_seq_len >= 2 * kvcache_partition_len`).
+ - If `attn_impl` is `"flash_attn"` and `kvcache_partition_len` is `None`, it defaults to
+ 16,384.
+
+
+ KV Cache Number of Blocks:
+ `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
+ Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+ - **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
+ the maximum number of blocks that can fit into the available RBLN device memory. This
+ calculation considers the model size (kernel memory), required buffer memory, the number
+ of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
+ This aims to maximize cache capacity for potentially better performance with long sequences
+ or larger batches without manual tuning.
+ - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
+ but requires careful consideration of memory limits. Setting it too high may lead to
+ compilation errors if it exceeds available memory. The system will issue warnings if your
+ setting exceeds the estimated maximum.
+ - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
+ which is beneficial for tasks involving many long sequences or large batch sizes, enabling
+ higher throughput. However, allocating more blocks consumes more memory.
+ - **Minimum Requirement**: The system requires a minimum number of blocks to function,
+ calculated based on `max_seq_len`, `kvcache_block_size`, and `batch_size`. The number of
+ allocated blocks must be sufficient to hold at least one full sequence length per item
+ in the batch concurrently. The system will log warnings or raise errors if constraints
+ are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
+
+ The optimal value depends on the specific model, task, hardware, and desired trade-off
+ between performance and memory usage. The automatic estimation provides a robust starting point.
  """
+
  super().__init__(**kwargs)
  self.batch_size = batch_size or 1
  if not isinstance(self.batch_size, int) or self.batch_size < 0:
@@ -121,6 +214,10 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
  # Larger batch size should be at the beginning of the list.
  self.decoder_batch_sizes.sort(reverse=True)

+ self.cache_impl = cache_impl or "static"
+ self.sliding_window = sliding_window
+ self.sliding_window_layers = sliding_window_layers or []
+
  @property
  def use_multiple_decoder(self):
  return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
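For orientation, and not part of the diff itself: a minimal sketch of how the new cache options documented above might be passed when building this config. The deep import path follows this wheel's module layout, and the values are illustrative assumptions chosen to satisfy the documented constraints, not recommended defaults.

# Illustrative sketch only (not from the package): exercising the new
# cache_impl / sliding_window arguments added in 0.8.1a2.
from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
    RBLNDecoderOnlyModelForCausalLMConfig,
)

cfg = RBLNDecoderOnlyModelForCausalLMConfig(
    batch_size=1,
    max_seq_len=4096,
    cache_impl="hybrid",              # new: "static" (default), "sliding_window", or "hybrid"
    sliding_window=1024,              # must stay below 32768 - prefill_chunk_size (default 128)
    sliding_window_layers=[0, 2, 4],  # layers that keep a local (sliding-window) cache in hybrid mode
)
print(cfg.cache_impl, cfg.sliding_window, cfg.sliding_window_layers)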
@@ -21,6 +21,7 @@ from transformers import PretrainedConfig, PreTrainedModel

  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+ from .configuration_decoderonly import CacheImplType


  logger = logging.get_logger(__name__)
@@ -30,6 +31,7 @@ DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
  MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
  MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
  MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
+ MAX_SLIDING_WINDOW_SIZE = 32_768


  def set_default_values(
@@ -114,6 +116,13 @@ def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcach
  )


+ def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
+ if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
+ raise ValueError(
+ f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
+ )
+
+
  class DecoderOnlyWrapper(nn.Module):
  """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.

@@ -146,12 +155,15 @@ class DecoderOnlyWrapper(nn.Module):
  max_seq_len: int,
  use_rotary_emb: bool,
  attn_impl: str,
+ cache_impl: CacheImplType,
  use_inputs_embeds: bool,
  use_attention_mask: bool,
  use_position_ids: bool,
  use_learned_pos_emb: Optional[bool] = None,
  kvcache_partition_len: Optional[int] = None,
  kvcache_block_size: Optional[int] = None,
+ sliding_window: Optional[int] = None,
+ sliding_window_layers: Optional[List[int]] = None,
  ):
  super().__init__()
  self.config = causal_lm.config
@@ -171,6 +183,9 @@ class DecoderOnlyWrapper(nn.Module):
  self.use_position_ids = use_position_ids
  self.use_inputs_embeds = use_inputs_embeds
  self.use_learned_pos_emb = use_learned_pos_emb
+ self.sliding_window_layers = sliding_window_layers
+ self.cache_impl = cache_impl
+ self.sliding_window = sliding_window

  if self.attn_impl == "flash_attn":
  self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
@@ -186,7 +201,6 @@ class DecoderOnlyWrapper(nn.Module):
  )

  self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
-
  self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
  self._phase = "prefill"

@@ -195,25 +209,35 @@ class DecoderOnlyWrapper(nn.Module):

  def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
  new_layers = []
-
- for layer in causal_lm.model.layers:
- if self.attn_impl == "eager":
+ for layer_idx, layer in enumerate(causal_lm.model.layers):
+ if layer_idx in self.sliding_window_layers:
+ # Flash attention is not yet supported for sliding window attention.
  new_self_attn = DecoderOnlyAttention(
  layer.self_attn,
  self.use_attention_mask,
  self.use_position_ids,
- kvcache_block_size=self.kvcache_block_size,
- )
- elif self.attn_impl == "flash_attn":
- new_self_attn = DecoderOnlyFlashAttention(
- layer.self_attn,
- kvcache_partition_len=self.kvcache_partition_len,
- kvcache_block_size=self.kvcache_block_size,
- use_attention_mask=self.use_attention_mask,
- use_position_ids=self.use_position_ids,
+ kvcache_block_size=self.sliding_window,
+ is_sliding=True,
  )
  else:
- raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+ if self.attn_impl == "eager":
+ new_self_attn = DecoderOnlyAttention(
+ layer.self_attn,
+ self.use_attention_mask,
+ self.use_position_ids,
+ kvcache_block_size=self.kvcache_block_size,
+ is_sliding=False,
+ )
+ elif self.attn_impl == "flash_attn":
+ new_self_attn = DecoderOnlyFlashAttention(
+ layer.self_attn,
+ kvcache_partition_len=self.kvcache_partition_len,
+ kvcache_block_size=self.kvcache_block_size,
+ use_attention_mask=self.use_attention_mask,
+ use_position_ids=self.use_position_ids,
+ )
+ else:
+ raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")

  new_layer = DecoderOnlyLayer(layer, new_self_attn)
  new_layers.append(new_layer)
@@ -225,6 +249,7 @@ class DecoderOnlyWrapper(nn.Module):
  max_seq_len=max_seq_len,
  kvcache_block_size=self.kvcache_block_size,
  use_learned_pos_emb=self.use_learned_pos_emb,
+ sliding_window_layers=self.sliding_window_layers,
  )
  new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
  return new_causal_lm
@@ -243,8 +268,9 @@ class DecoderOnlyWrapper(nn.Module):
  input_ids = None if self.use_inputs_embeds else args.pop(0)
  inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
  cache_position = args.pop(0)
- block_tables = args.pop(0)
- query_position = args.pop(0) if self.phase == "prefill" else None
+ global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
+ local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
+ query_position = args.pop(0) if "prefill" in self.phase else None
  attention_mask = args.pop(0) if self.use_attention_mask else None
  position_ids = args.pop(0) if self.use_position_ids else None
  past_key_values = args
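As a reading aid, not part of the diff: a small self-contained sketch of the positional layout that the reworked unpacking above implies, assuming a hybrid cache in a prefill phase with the mask and position-id flags disabled. All names and placeholder values are hypothetical.

# Hypothetical walk-through of the pops above (cache_impl="hybrid", prefill phase,
# use_inputs_embeds / use_attention_mask / use_position_ids all False).
args = ["input_ids", "cache_position", "global_block_tables",
        "local_block_tables", "query_position", "kv_cache_0", "kv_cache_1"]
cache_impl, phase = "hybrid", "prefill"
use_inputs_embeds = use_attention_mask = use_position_ids = False

input_ids = None if use_inputs_embeds else args.pop(0)
inputs_embeds = args.pop(0) if use_inputs_embeds else None
cache_position = args.pop(0)
global_block_tables = args.pop(0) if cache_impl in ["hybrid", "static"] else None
local_block_tables = args.pop(0) if cache_impl in ["hybrid", "sliding_window"] else None
query_position = args.pop(0) if "prefill" in phase else None
attention_mask = args.pop(0) if use_attention_mask else None
position_ids = args.pop(0) if use_position_ids else None
past_key_values = args  # what remains: ["kv_cache_0", "kv_cache_1"]
print(global_block_tables, local_block_tables, query_position, past_key_values)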
@@ -264,16 +290,22 @@ class DecoderOnlyWrapper(nn.Module):
  _past_key_values.append(past_key_value)
  past_key_values = _past_key_values

+ if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
+ rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
+ else:
+ rotary_emb = self.rotary_emb
+
  return (
  input_ids,
  inputs_embeds,
  cache_position,
- block_tables,
+ global_block_tables,
+ local_block_tables,
  query_position,
  attention_mask,
  position_ids,
  past_key_values,
- self.rotary_emb,
+ rotary_emb,
  )

  def forward(self, *args):
@@ -281,7 +313,8 @@ class DecoderOnlyWrapper(nn.Module):
  input_ids,
  inputs_embeds,
  cache_position,
- block_tables,
+ global_block_tables,
+ local_block_tables,
  query_position,
  attention_mask,
  position_ids,
@@ -298,7 +331,8 @@ class DecoderOnlyWrapper(nn.Module):
  query_position=query_position,
  past_key_values=past_key_values,
  rotary_emb=rotary_emb,
- block_tables=block_tables,
+ global_block_tables=global_block_tables,
+ local_block_tables=local_block_tables,
  )

  return logit
@@ -353,7 +387,8 @@ class DecoderOnlyForCausalLM(nn.Module):
  query_position: torch.Tensor = None,
  past_key_values: Tuple[Tuple[torch.Tensor]] = None,
  rotary_emb: nn.Module = None,
- block_tables: Optional[torch.Tensor] = None,
+ global_block_tables: Optional[torch.Tensor] = None,
+ local_block_tables: Optional[torch.Tensor] = None,
  ):
  # outputs
  hidden_states = self.model(
@@ -362,12 +397,14 @@ class DecoderOnlyForCausalLM(nn.Module):
  attention_mask=attention_mask,
  cache_position=cache_position,
  position_ids=position_ids,
+ query_position=query_position,
  past_key_values=past_key_values,
  rotary_emb=rotary_emb,
- block_tables=block_tables,
+ global_block_tables=global_block_tables,
+ local_block_tables=local_block_tables,
  )

- if self.phase == "prefill":
+ if "prefill" in self.phase:
  hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

  logits = self.lm_head(hidden_states)
@@ -402,6 +439,7 @@ class DecoderOnlyModel(nn.Module):
  max_seq_len=None,
  kvcache_block_size=None,
  use_learned_pos_emb=None,
+ sliding_window_layers=None,
  ):
  super().__init__()
  self._original_mod = model
@@ -411,6 +449,7 @@ class DecoderOnlyModel(nn.Module):
  self.kvcache_block_size = kvcache_block_size
  self.max_seq_len = max_seq_len
  self.use_learned_pos_emb = use_learned_pos_emb
+ self.sliding_window_layers = sliding_window_layers

  @property
  def phase(self):
@@ -441,6 +480,16 @@ class DecoderOnlyModel(nn.Module):
  cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
  return cache_pos_for_partitions

+ def get_local_cache_positions(self, position_ids, query_position):
+ max_cache_len = self._original_mod.config.sliding_window
+ valid_input_len = 1 if query_position is None else query_position + 1
+ cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1] # past seen tokens
+ cache_offset = (
+ torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
+ ) # cache offset for next steps
+
+ return cache_seq_len, cache_offset
+
  def get_last_layernorm(self) -> nn.LayerNorm:
  return self._original_mod.norm

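To make the new helper concrete, and again not part of the diff: a standalone re-run of the same arithmetic as get_local_cache_positions with made-up numbers, showing how clamped position ids become the local cache's seen-token count and write offset.

import torch

# Same arithmetic as get_local_cache_positions above, with illustrative values only.
sliding_window = 4                        # stands in for config.sliding_window
position_ids = torch.tensor([[6, 7, 8]])  # absolute positions of the current prefill chunk
query_position = torch.tensor(2)          # index of the last valid token in the chunk

valid_input_len = 1 if query_position is None else query_position + 1  # -> 3 tokens this step
cache_seq_len = torch.clamp(position_ids, max=sliding_window)[:, :1]   # -> tensor([[4]]): local cache already full
cache_offset = torch.clamp(position_ids, max=sliding_window)[:, :1] + valid_input_len  # -> tensor([[7]])
print(cache_seq_len, cache_offset)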
@@ -459,9 +508,11 @@ class DecoderOnlyModel(nn.Module):
  attention_mask: torch.Tensor = None,
  cache_position: torch.Tensor = None,
  position_ids: torch.Tensor = None,
+ query_position: torch.Tensor = None,
  past_key_values: Tuple[Tuple[torch.Tensor]] = None,
  rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
- block_tables: Optional[torch.Tensor] = None,
+ global_block_tables: Optional[torch.Tensor] = None,
+ local_block_tables: Optional[torch.Tensor] = None,
  ):
  # retrieve input_ids and inputs_embeds
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -511,7 +562,7 @@ class DecoderOnlyModel(nn.Module):
  hidden_states = hidden_states + position_embeds
  cos, sin = None, None

- # (batch, seq_len) -> (batch,)
+ # Get sequence positions for flash attention
  if self.attn_impl == "flash_attn":
  seq_positions = cache_position[:, 0]
  seq_positions = self.convert_sequence_positions_for_flash_attn(
@@ -520,15 +571,20 @@ class DecoderOnlyModel(nn.Module):
  else:
  seq_positions = cache_position[:, :1]

- for layer in self.layers:
+ # Get local cache positions for sliding window layers
+ if len(self.sliding_window_layers) > 0:
+ sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+ for layer_idx, layer in enumerate(self.layers):
+ is_sliding = True if layer_idx in self.sliding_window_layers else False
  hidden_states = layer(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
- seq_positions=seq_positions,
+ seq_positions=sliding_cache_pos if is_sliding else seq_positions,
  past_key_values=past_key_values,
  cos=cos,
  sin=sin,
- block_tables=block_tables,
+ block_tables=local_block_tables if is_sliding else global_block_tables,
  )

  hidden_states = self.get_last_layernorm()(hidden_states)
@@ -625,7 +681,7 @@ class DecoderOnlyAttention(nn.Module):
  self_attn: Original attention module from the base model
  """

- def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size):
+ def __init__(self, self_attn, use_attention_mask, use_position_ids, kvcache_block_size, is_sliding=False):
  super().__init__()
  self._original_mod = self_attn
  self.layer_idx = self_attn.layer_idx
@@ -645,6 +701,7 @@ class DecoderOnlyAttention(nn.Module):

  self.use_attention_mask = use_attention_mask
  self.use_position_ids = use_position_ids
+ self.is_sliding = is_sliding
  self.attention = self.get_attention()
  self.kvcache_block_size = kvcache_block_size
  self.__post_init__()
@@ -659,9 +716,14 @@ class DecoderOnlyAttention(nn.Module):
  self.attention.phase = phase

  def get_attention(self):
- return AttentionOp(
- self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
- )
+ if self.is_sliding:
+ return SlidingWindowAttentionOp(
+ self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+ )
+ else:
+ return AttentionOp(
+ self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+ )

  def __post_init__(self):
  self.q_proj = self._original_mod.q_proj
@@ -708,12 +770,14 @@ class DecoderOnlyAttention(nn.Module):
  value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
  1, 2
  )
- # b, num_head, query, head_dim
+ if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)

  if cos is not None and sin is not None:
  query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)

- if batch_size > 1 and self.phase == "prefill":
+ if batch_size > 1 and "prefill" in self.phase:
  raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

  attn_output = self.attention(
@@ -987,7 +1051,10 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
  value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
  1, 2
  )
- # b, num_head, query, head_dim
+
+ if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)

  if cos is not None and sin is not None:
  query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)