optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
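The expanded hunks below all touch DecoderOnlyWrapper, DecoderOnlyForCausalLM, DecoderOnlyModel, AttentionOp, and FlashAttentionOp, so they appear to be the per-line diff for optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (+197 -178 in the list above).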
@@ -13,18 +13,12 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
-from ....ops import (
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-    register_rbln_custom_paged_flash_attention,
-    register_rbln_custom_paged_flash_causal_attention,
-)
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 
@@ -38,30 +32,39 @@ MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
 MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
 
 
-def
-
-
-
-
-
-
-
-
-
-
-
+def set_default_values(
+    attn_impl: Optional[str] = None,
+    kvcache_partition_len: Optional[int] = None,
+    kvcache_block_size: Optional[int] = None,
+    max_seq_len: Optional[int] = None,
+) -> Tuple[str, int, int]:
+    if attn_impl is None:
+        attn_impl = "eager"
+
+    if kvcache_partition_len is not None:
+        if attn_impl == "eager":
+            attn_impl = "flash_attn"
             logger.warning(
-                "A non-null `
-                "Since KV cache partitioning is only supported with flash attention, "
-                "`
+                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
+                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
+                "`attn_impl` has been automatically switched to 'flash_attn'."
             )
 
-
-
-
+    if kvcache_partition_len is None and attn_impl == "flash_attn":
+        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
+
+    if kvcache_block_size is None:
+        if attn_impl == "eager":
+            kvcache_block_size = max_seq_len
+        else:
+            kvcache_block_size = kvcache_partition_len
+
+    return attn_impl, kvcache_partition_len, kvcache_block_size
+
 
-
-
+def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
+    if attn_impl not in ["eager", "flash_attn"]:
+        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
 
     ## Checking Constraints...
     # Constraint of eager attention:
@@ -71,47 +74,45 @@ def validate_attention_method(
     # 1. `max_seq_len` should be multiple of `partition_len`.
     # 2. 4k <= `partition_len` <= 32k.
     # 3. `max_seq_len` should be larger then 8k.
-    if
+    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
         raise ValueError(
-            f"`
+            f"`max_seq_len` is set to {max_seq_len}, "
             f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `
-            " or consider switching `
+            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
+            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
         )
 
-    if
-        if
+    if attn_impl == "flash_attn":
+        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
             raise ValueError(
-                f"`
+                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
                 f"when using 'flash_attn'. Please adjust either value to meet this requirement."
             )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <=
+        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
             raise ValueError(
-                f"`
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `
+                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
+                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
                 f"Please provide a valid value within this range."
             )
-        elif
+        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
             raise ValueError(
-                f"`
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `
-                "this requirement, or consider switching `
+                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
+                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
+                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
             )
 
-    if
-        if
+    if kvcache_block_size is not None:
+        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
             raise ValueError(
-                f" When using 'flash attention', the `
-                f"must always be set equal to the `
+                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
             )
-        elif
+        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
             raise ValueError(
-                f" When using 'eager attention', the `
-                f"must always be set equal to the `
+                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
+                f"must always be set equal to the `max_seq_len` {max_seq_len}."
             )
 
-    return rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size
-
 
 class DecoderOnlyWrapper(nn.Module):
     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
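To make the interplay of the two new helpers concrete, here is a small usage sketch based only on the code in the hunks above; the import path and the concrete numbers are illustrative assumptions, and the constant values are taken as given in the file rather than restated.

```python
# Hypothetical usage of the helpers added above; the import path is an assumption.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    set_default_values,
    validate_attention_method,
)

# Nothing specified: attn_impl falls back to "eager", and for eager attention the
# KV-cache block size defaults to max_seq_len (one block spanning the whole sequence).
attn_impl, partition_len, block_size = set_default_values(max_seq_len=4096)
assert (attn_impl, partition_len, block_size) == ("eager", None, 4096)

# A partition length implies flash attention: attn_impl is switched (with a warning)
# and the block size defaults to the partition length.
attn_impl, partition_len, block_size = set_default_values(
    kvcache_partition_len=8192, max_seq_len=32768
)
assert (attn_impl, partition_len, block_size) == ("flash_attn", 8192, 8192)

# validate_attention_method then enforces the flash-attn constraints shown above:
# max_seq_len must be a multiple of (and at least twice) kvcache_partition_len,
# the partition length must sit in the supported range, and block_size must equal it.
validate_attention_method(attn_impl, partition_len, block_size, max_seq_len=32768)
```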
@@ -162,16 +163,8 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_attention_mask = use_attention_mask
         if self.attn_impl == "flash_attn":
             self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-            if self.use_attention_mask:
-                register_rbln_custom_paged_flash_attention()
-            else:
-                register_rbln_custom_paged_flash_causal_attention()
         elif self.attn_impl == "eager":
             self.kvcache_partition_len = None
-            if self.use_attention_mask:
-                register_rbln_custom_paged_attention()
-            else:
-                register_rbln_custom_paged_causal_attention()
         else:
             raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
 
@@ -191,6 +184,7 @@ class DecoderOnlyWrapper(nn.Module):
 
     def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
         new_layers = []
+
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
@@ -208,6 +202,7 @@ class DecoderOnlyWrapper(nn.Module):
 
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
+
         new_model = DecoderOnlyModel(
             causal_lm.model,
             new_layers,
@@ -227,6 +222,53 @@ class DecoderOnlyWrapper(nn.Module):
         self._phase = phase
         self.causal_lm.phase = phase
 
+    def forward_common(
+        self,
+        input_ids_or_inputs_embeds: torch.Tensor,
+        cache_position: torch.Tensor,
+        attention_mask: torch.Tensor,
+        query_position: torch.Tensor,
+        block_tables: torch.Tensor,
+        rotary_emb: Union[nn.Module, torch.Tensor],
+        *past_key_values: List[torch.Tensor],
+    ):
+        if input_ids_or_inputs_embeds.ndim == 2:
+            input_ids = input_ids_or_inputs_embeds
+            inputs_embeds = None
+        elif input_ids_or_inputs_embeds.ndim == 3:
+            input_ids = None
+            inputs_embeds = input_ids_or_inputs_embeds
+        else:
+            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        logit = self.causal_lm(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            query_position=query_position,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+            block_tables=block_tables,
+        )
+
+        return logit
+
     def forward(self, *args):
         if self.phase == "decode":
             if self.use_attention_mask:
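forward_common receives the per-layer KV caches flattened into a single argument list; the regrouping it performs is shown standalone below (tensor sizes and the two-layer setup are assumptions for illustration):

```python
import torch

num_hidden_layers = 2  # assumed for illustration
# The wrapper is called with [key, value] * n_layer ...
flat = [torch.zeros(1, 8, 1, 1024, 64) for _ in range(2 * num_hidden_layers)]

# ... and forward_common regroups them into one [key_states, value_states] pair per layer,
# the layout the inner causal_lm expects (cache shape: batch, n_heads, 1, max_seq_len, head_dim).
paired = [[flat[i * 2], flat[i * 2 + 1]] for i in range(num_hidden_layers)]

assert len(paired) == num_hidden_layers and len(paired[0]) == 2
```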
@@ -269,43 +311,16 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             raise ValueError(f"Unknown phase: {self.phase}")
 
-
-
-
-
-
-
-
-
-
-        if len(past_key_values) != 2 * self.num_hidden_layers:
-            raise ValueError(
-                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
-            )
-
-        # [key, value] * n_layer -> ( (key, value) ) * n_layer
-        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-        _past_key_values = []
-        for i in range(self.config.num_hidden_layers):
-            key_states = past_key_values[i * 2]
-            value_states = past_key_values[i * 2 + 1]
-            past_key_value = [key_states, value_states]
-            _past_key_values.append(past_key_value)
-        past_key_values = _past_key_values
-
-        logit = self.causal_lm(
-            input_ids=input_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            query_position=query_position,
-            past_key_values=past_key_values,
-            rotary_emb=self.rotary_emb,
-            block_tables=block_tables,
+        return self.forward_common(
+            input_ids_or_inputs_embeds,
+            cache_position,
+            attention_mask,
+            query_position,
+            block_tables,
+            self.rotary_emb,
+            *past_key_values,
         )
 
-        return logit
-
 
 class DecoderOnlyForCausalLM(nn.Module):
     """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -329,12 +344,13 @@ class DecoderOnlyForCausalLM(nn.Module):
         _phase: Current processing phase ("prefill" or "decode")
     """
 
-    def __init__(self, causal_lm: PreTrainedModel, model):
+    def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
         super().__init__()
         self.config = causal_lm.config
         self._original_mod = causal_lm
         self.model = model
         self._phase = "prefill"
+        self.lm_head = self._original_mod.lm_head
 
     @property
     def phase(self):
@@ -370,7 +386,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         if self.phase == "prefill":
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
-        logits = self.
+        logits = self.lm_head(hidden_states)
         return logits
 
 
@@ -462,8 +478,12 @@ class DecoderOnlyModel(nn.Module):
 
         # get cos,sin vector if needed
         if rotary_emb is not None:
-
-
+            if isinstance(rotary_emb, torch.Tensor):
+                cos = rotary_emb[0]
+                sin = rotary_emb[1]
+            else:
+                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+                cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
         else:
             batch_size = inputs_embeds.shape[0]
             if cache_position.shape[0] > 1:
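With rotary_emb typed as Union[nn.Module, torch.Tensor], DecoderOnlyModel can now be handed either the rotary-embedding module itself or a precomputed tensor whose first and second slices are cos and sin, which is how the isinstance branch above unpacks it; a minimal sketch of the tensor form (shapes are assumptions):

```python
import torch

cos = torch.ones(1, 128, 64)
sin = torch.zeros(1, 128, 64)
precomputed = torch.stack([cos, sin])  # precomputed[0] -> cos, precomputed[1] -> sin
assert torch.equal(precomputed[0], cos) and torch.equal(precomputed[1], sin)
```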
@@ -756,55 +776,55 @@ class AttentionOp(nn.Module):
         if self.phase == "decode":
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
 
         else:
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    block_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=block_size,
                 )
 
         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
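The kcache/vcache keyword arguments are the per-layer caches with a singleton dimension re-inserted via unsqueeze(2), matching the layout noted in forward_common (batch, n_heads, 1, max_seq_len, head_dim); a quick shape check under assumed sizes:

```python
import torch

# Assumed sizes for illustration: batch=2, n_heads=8, max_seq_len=4096, head_dim=64.
past_key_state = torch.zeros(2, 8, 4096, 64)
kcache = past_key_state.unsqueeze(2)        # what the op receives as kcache=...
assert kcache.shape == (2, 8, 1, 4096, 64)  # batch, n_heads, 1, max_seq_len, head_dim
```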
@@ -840,7 +860,6 @@ def rotate_half(x):
 
 def apply_rotary_pos_emb(q, k, cos, sin):
     """Applies Rotary Position Embedding to the query and key tensors."""
-
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
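apply_rotary_pos_emb depends on rotate_half, whose body falls outside this hunk; the sketch below pairs it with the standard half-rotation (an assumption about the definition just above the hunk) to show the rotation end to end:

```python
import torch

def rotate_half(x):
    # Standard half-rotation; assumed to match the rotate_half defined above this hunk.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies Rotary Position Embedding to the query and key tensors."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q = torch.randn(1, 8, 4, 64)   # batch, heads, seq, head_dim (assumed shapes)
k = torch.randn(1, 8, 4, 64)
cos, sin = torch.ones(1, 1, 4, 64), torch.zeros(1, 1, 4, 64)  # identity rotation
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.equal(q_rot, q) and torch.equal(k_rot, k)
```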
@@ -1015,58 +1034,58 @@ class FlashAttentionOp(AttentionOp):
         if self.phase == "decode":
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_decode(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
         else:
             if self.use_attention_mask:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    attn_mask,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    mask=attn_mask,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
             else:
                 attn_output = torch.ops.rbln_custom_ops.paged_flash_causal_attn_prefill(
-                    query_state,
-                    key_state,
-                    value_state,
-                    past_key_state.unsqueeze(2),
-                    past_value_state.unsqueeze(2),
-                    seq_position,
-                    scale,
-                    block_tables,
-                    kvcache_block_size,
-                    self.kvcache_partition_size,
+                    q=query_state,
+                    k=key_state,
+                    v=value_state,
+                    kcache=past_key_state.unsqueeze(2),
+                    vcache=past_value_state.unsqueeze(2),
+                    seq=seq_position,
+                    scale=scale,
+                    block_table=block_tables,
+                    block_size=kvcache_block_size,
+                    partition=self.kvcache_partition_size,
                 )
 
         # reshape for removing repeat_kv
|