optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +24 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/__init__.py +21 -1
- optimum/rbln/diffusers/configurations/__init__.py +4 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
- optimum/rbln/diffusers/modeling_diffusers.py +72 -65
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
- optimum/rbln/diffusers/models/controlnet.py +14 -8
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
- optimum/rbln/diffusers/pipelines/__init__.py +10 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +71 -37
- optimum/rbln/modeling_base.py +63 -109
- optimum/rbln/transformers/__init__.py +41 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +21 -22
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +54 -4
- optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
- optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
- optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
- optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
- optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
- optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/runtime_utils.py +49 -1
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
- optimum_rbln-0.8.1.dist-info/RECORD +211 -0
- optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
The hunks below all come from optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (+253 -195), shown as a unified diff.

```diff
@@ -36,6 +36,7 @@ from .decoderonly_architecture import (
     DecoderOnlyWrapper,
     set_default_values,
     validate_attention_method,
+    validate_sliding_window_size,
 )
 
 
@@ -56,39 +57,28 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         block_tables: torch.Tensor,
         free_block_pool: Deque,
-        kvcache_block_size: int,
-        use_attention_mask: bool,
-        attn_impl: str,
-        use_position_ids: bool,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
-
-        # shared data structures between prefill and decode phase
-        self.use_attention_mask = use_attention_mask
+        self.rbln_config = rbln_config
 
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
-        self.use_position_ids = use_position_ids
 
-        self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
-        self.attn_impl = attn_impl
-
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
-            self.max_seq_len = kwargs.pop("max_seq_len")
-            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
             self.output_size = [1, 1, vocab_size]
             self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
 
-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
         """
         Manages and returns the KV cache block tables.
         Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
@@ -98,7 +88,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
 
         Returns:
-            torch.Tensor: Updated block tables.
+            Updated block tables.
         """
 
         NO_BLOCKS_ERROR = (
```
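This constructor change is the pattern for most of the file: instead of threading each option (`use_attention_mask`, `attn_impl`, `use_position_ids`, ...) through as a separate argument, 0.8.1 passes the whole `rbln_config` object and reads attributes at the call sites. A minimal sketch of the idea; the `RuntimeConfig` dataclass below is a hypothetical stand-in for `RBLNDecoderOnlyModelForCausalLMConfig`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RuntimeConfig:
    # Hypothetical stand-in; only the fields exercised below are modeled.
    max_seq_len: int = 4096
    prefill_chunk_size: int = 128
    kvcache_block_size: int = 128
    use_attention_mask: bool = False
    use_position_ids: bool = False
    cache_impl: str = "static"  # "static" | "sliding_window" | "hybrid"
    sliding_window: Optional[int] = None


class Runtime:
    # Before: __init__(..., kvcache_block_size, use_attention_mask, attn_impl, use_position_ids)
    # After: one config object; newly added options ride along without signature changes.
    def __init__(self, phase: str, batch_size: int, rbln_config: RuntimeConfig) -> None:
        self.phase = phase
        self.batch_size = batch_size
        self.rbln_config = rbln_config


rt = Runtime("prefill", batch_size=1, rbln_config=RuntimeConfig())
print(rt.rbln_config.prefill_chunk_size)  # 128
```

One practical payoff shows up later in this diff: the new `cache_impl`, `sliding_window`, and `sliding_window_layers` options reach the runtime without another constructor change.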
```diff
@@ -131,31 +121,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             else:
                 raise RuntimeError(NO_BLOCKS_ERROR)
 
-        if self.phase == "prefill":
-            # Track previously used blocks and return them to the free_block_pool and
-            # reset the current batch's block table to empty blocks
-            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-            self.free_block_pool.extend(prev_blocks)
-            self.block_tables[batch_idx].fill_(self.empty_block)
-
-            # Get the start (s) and end (e) positions from cache_position and
-            # iterate over the cache positions to allocate necessary blocks
-            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-            for position in range(s, e + 1, self.kvcache_block_size):
-                block_idx = position // self.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                update_block(batch_idx, block_idx)
-
-            return replace_empty_block(self.block_tables[batch_idx])
-        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-        else:
-            for b_idx in range(self.batch_size):
-                position = cache_position[b_idx][0].item()
-                block_idx = position // self.kvcache_block_size
-                update_block(b_idx, block_idx)
+        def get_global_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "sliding_window":
+                return None
+
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return replace_empty_block(self.block_tables[batch_idx])
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    update_block(b_idx, block_idx)
 
-            return replace_empty_block(self.block_tables)
+                return replace_empty_block(self.block_tables)
+
+        def get_local_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "static":
+                return None
+            else:
+                return (
+                    torch.tensor([batch_idx], dtype=torch.int16)
+                    if self.phase == "prefill"
+                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
+                )
+
+        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
+
+    def is_external_block_tables(
+        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+    ):
+        if self.rbln_config.cache_impl == "static" and block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "hybrid":
+            if (block_tables is not None) != (local_block_tables is not None):
+                raise ValueError(
+                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                )
+        elif block_tables is None and local_block_tables is None:
+            return False
+
+        return True
 
     def forward(
         self,
```
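With the split into `get_global_block_tables` and `get_local_block_tables`, `get_block_tables` now returns a `(global, local)` pair, and the new `is_external_block_tables` helper decides whether caller-supplied tables should be used as-is. A standalone restatement of that decision logic, for illustration only:

```python
from typing import Optional

import torch


def is_external_block_tables(
    cache_impl: str,
    block_tables: Optional[torch.Tensor],
    local_block_tables: Optional[torch.Tensor],
) -> bool:
    # Mirrors the method added in the hunk above, outside the class.
    if cache_impl == "static" and block_tables is None:
        return False
    elif cache_impl == "sliding_window" and local_block_tables is None:
        return False
    elif cache_impl == "hybrid":
        if (block_tables is not None) != (local_block_tables is not None):
            raise ValueError(
                "Both block_tables and local_block_tables must be provided "
                "or neither of them must be provided."
            )
    elif block_tables is None and local_block_tables is None:
        return False
    return True


print(is_external_block_tables("static", None, None))  # False: allocate internally
print(is_external_block_tables("sliding_window", None, torch.tensor([0], dtype=torch.int16)))  # True
```

The hybrid branch rejects a half-supplied pair outright: a hybrid cache needs both the paged global tables and the per-batch local (sliding-window) tables.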
```diff
@@ -180,11 +203,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds
 
-        if block_tables is None:
-            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-            is_external_block_tables = False
-        else:
-            is_external_block_tables = True
+        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+        if not is_external_block_tables:
+            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
 
         if self.phase == "decode":
             return self.decode_forward(
@@ -204,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask,
                 batch_idx,
                 block_tables,
+                is_external_block_tables=is_external_block_tables,
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
@@ -229,7 +251,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -245,7 +267,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
             attention_mask = self.dec_attn_mask
 
-        if self.batch_size < block_tables.shape[0]:
+        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
             block_tables = block_tables[: self.batch_size]
 
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -255,9 +277,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             block_tables,
+            local_block_tables,
             position_embed,
-            attention_mask if self.use_attention_mask else None,
-            position_ids if self.use_position_ids else None,
+            attention_mask if self.rbln_config.use_attention_mask else None,
+            position_ids if self.rbln_config.use_position_ids else None,
         )
 
         return RBLNDecoderOnlyOutput(logits=logits)
@@ -268,7 +291,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
@@ -283,15 +305,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Initialize attention mask for chunked processing
         chunked_attention_mask = (
-            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-            if self.use_attention_mask
+            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask
             else None
         )
 
@@ -305,8 +327,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         ]
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
         # inputs_embeds
         if inputs.dim() == 3:
             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
```
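The padding change makes `padding_size = 0` the explicit default, so the no-padding path is well defined when `query_length` already lands on a chunk boundary. A small sketch of the arithmetic with a hypothetical `pad_to_chunk` helper:

```python
import torch

PREFILL_CHUNK_SIZE = 128


def pad_to_chunk(inputs: torch.Tensor) -> torch.Tensor:
    # Pad the sequence dim up to the next multiple of PREFILL_CHUNK_SIZE,
    # mirroring the hunk above for the 2-D input_ids case.
    query_length = inputs.shape[1]
    padding_size = 0
    if query_length % PREFILL_CHUNK_SIZE != 0:
        padding_size = (PREFILL_CHUNK_SIZE - query_length) % PREFILL_CHUNK_SIZE
    return torch.nn.functional.pad(inputs, (0, padding_size))


ids = torch.ones(1, 300, dtype=torch.int64)
print(pad_to_chunk(ids).shape)  # torch.Size([1, 384]); 300 tokens -> 3 chunks of 128
```

Python's `%` never goes negative, so `(chunk - length) % chunk` is the distance to the next chunk boundary even when `length` exceeds one chunk (here `(128 - 300) % 128 == 84`).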
```diff
@@ -351,10 +374,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
+        is_external_block_tables: bool = False,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -375,39 +398,47 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
             if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
+                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
 
-            if self.use_attention_mask and not self.use_position_ids:
+            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
-                if step >= self.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if step >= self.rbln_config.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
 
             # Define query position
-            query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+            if step + self.rbln_config.prefill_chunk_size >= query_length:
+                query_position = torch.tensor(
+                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
+                )
+            else:
+                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
 
             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                local_block_tables,
                 position_embed_chunk if position_embed is not None else None,
                 query_position,
-                chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if self.use_position_ids else None,
+                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                position_ids_chunk if self.rbln_config.use_position_ids else None,
                 out=out_buffers,
             )
 
             # Update decoder attention mask with processed KV-cache length from prefill phase
-            if not is_external_block_tables and self.use_attention_mask:
+            if not is_external_block_tables and self.rbln_config.use_attention_mask:
                 self.dec_attn_mask[batch_idx].fill_(0)
                 self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
```
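The chunk loop now computes `query_position` per chunk: intermediate full chunks point at their last slot, and only the final (possibly partial) chunk points at the last real token. A toy reproduction of just that piece:

```python
import torch


def query_positions(query_length: int, prefill_chunk_size: int) -> list:
    # Offsets of the last meaningful token inside each prefill chunk,
    # as computed in the loop above.
    positions = []
    for step in range(0, query_length, prefill_chunk_size):
        if step + prefill_chunk_size >= query_length:
            q = torch.tensor((query_length - 1) % prefill_chunk_size, dtype=torch.int16)
        else:
            q = torch.tensor(prefill_chunk_size - 1, dtype=torch.int16)
        positions.append(int(q))
    return positions


# 300 tokens in chunks of 128: two full chunks, then a partial chunk whose
# last real token sits at offset (300 - 1) % 128 == 43.
print(query_positions(300, 128))  # [127, 127, 43]
```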
```diff
@@ -427,6 +458,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
 
     The class provides core functionality for:
+
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
@@ -477,13 +509,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            rbln_config=self.rbln_config,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
         )
 
         self.decoders = {}
@@ -497,10 +524,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-                kvcache_block_size=self.rbln_config.kvcache_block_size,
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -509,15 +533,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def save_torch_artifacts(
         cls,
-        model: "PreTrainedModel",
+        model: PreTrainedModel,
         save_dir_path: Path,
         subfolder: str,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
         if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
@@ -545,7 +567,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
-        config: Optional["PretrainedConfig"] = None,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -584,16 +606,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return model
 
     def __getattr__(self, __name: str) -> Any:
-        """
-        Special method to delegate attribute access to the original Huggingface LM class.
-        This method is called when an attribute is not found in the current instance's dictionary.
-        It enables transparent access to the original model's attributes and methods while maintaining
-        proper method binding.
-
-        The method implements a delegation pattern that:
-        1. For methods: Creates a wrapper that properly binds 'self' to method calls
-        2. For other attributes: Returns them directly from the original class
-        """
+        # Special method to delegate attribute access to the original Huggingface LM class.
+        # This method is called when an attribute is not found in the current instance's dictionary.
+        # It enables transparent access to the original model's attributes and methods while maintaining
+        # proper method binding.
+
+        # The method implements a delegation pattern that:
+
+        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        # 2. For other attributes: Returns them directly from the original class
 
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
@@ -606,7 +627,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     @classmethod
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
-    ) -> "PreTrainedModel":
+    ) -> PreTrainedModel:
         if rbln_config and rbln_config.quantization:
             model = cls.get_quantized_model(*args, **kwargs)
         else:
@@ -615,7 +636,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return model
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
         wrapper_cfg = {
             "max_seq_len": rbln_config.max_seq_len,
             "attn_impl": rbln_config.attn_impl,
@@ -625,12 +646,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "use_attention_mask": rbln_config.use_attention_mask,
             "use_position_ids": rbln_config.use_position_ids,
             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+            "cache_impl": rbln_config.cache_impl,
+            "sliding_window": rbln_config.sliding_window,
+            "sliding_window_layers": rbln_config.sliding_window_layers,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
@@ -655,9 +679,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             quantization.maybe_set_quantization_env()
             original_linear = torch.nn.functional.linear
             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-            compiled_model =
+            compiled_model = cls.compile(
                 wrapped_model,
                 compile_config,
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device,
                 example_inputs=example_inputs,
                 compile_context=compile_context,
             )
@@ -709,7 +735,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
             ):
                 alloc_memory_by_key[key] += sum(memory_per_node)
-        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
+        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
 
         # Get the maximum number of blocks that can be allocated
```
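The `pop` defaults let the memory accounting run against either compiler generation, which report the KV cache under different keys. A trivial illustration with made-up report dictionaries:

```python
for report in (
    {"PortRecur": 3 << 20, "Kernel": 1 << 20},   # older compiler output
    {"DramTensor": 3 << 20, "Kernel": 1 << 20},  # newer compiler output
):
    report.pop("PortRecur", None)    # no KeyError if the key is absent
    report.pop("DramTensor", None)
    kernel_size = report.pop("Kernel")  # model weights: always expected
    print(kernel_size)  # 1048576 in both cases
```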
```diff
@@ -746,35 +773,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         buffer: Optional[int] = None,
         num_runtimes: int = 2,
     ) -> int:
-        """
-        We are finding max_n_blocks(x) that satisfies the following equation:
-
-        available_dram - kernel_size - buffer
-        - num_layers * 2 * tensor_parallel_size
-        * align_2MB(
-            x
-            * block_size
-            * align_64(head_dim)
-            * math.ceil(num_key_value_heads / tensor_parallel_size)
-            * 2
-        ) > 0
-
-        This inequality can be rewritten as follows:
-
-        a - c * align_2MB(b * x) > 0
-        where
-        a = available_dram - kernel_size - buffer
-        b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-        c = num_layers * 2 * tensor_parallel_size
-
-        We can rewrite the inequality as follows:
-        k > align_2MB(b*x)
-        where
-        k = a / c
-
-        After that, we can derive the following equation:
-        x = floor(2**21 / b * floor((k - 1) / 2**21))
-        """
+        # We are finding max_n_blocks(x) that satisfies the following equation:
+
+        # available_dram - kernel_size - buffer
+        # - num_layers * 2 * tensor_parallel_size
+        # * align_2MB(
+        #     x
+        #     * block_size
+        #     * align_64(head_dim)
+        #     * math.ceil(num_key_value_heads / tensor_parallel_size)
+        #     * 2
+        # ) > 0
+
+        # This inequality can be rewritten as follows:
+
+        # a - c * align_2MB(b * x) > 0
+        # where
+        # a = available_dram - kernel_size - buffer
+        # b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        # c = num_layers * 2 * tensor_parallel_size
+
+        # We can rewrite the inequality as follows:
+        # k > align_2MB(b*x)
+        # where
+        # k = a / c
+
+        # After that, we can derive the following equation:
+        # x = floor(2**21 / b * floor((k - 1) / 2**21))
 
         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
```
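The docstring-turned-comment derives a closed form for the largest number of KV-cache blocks that fits in device DRAM. A direct transcription of that formula, runnable with hypothetical sizing numbers; `2**21` is the 2 MiB alignment granularity and the trailing `* 2` is the fp16 element size:

```python
import math


def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)


def max_n_blocks(
    available_dram: int,
    kernel_size: int,
    buffer: int,
    num_layers: int,
    tensor_parallel_size: int,
    block_size: int,
    head_dim: int,
    num_key_value_heads: int,
) -> int:
    # x = floor(2**21 / b * floor((k - 1) / 2**21)), with a, b, c, k as in the comments.
    a = available_dram - kernel_size - buffer
    b = block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = a / c
    return int(2**21 / b * math.floor((k - 1) / 2**21))


# Hypothetical sizing: 16 GiB DRAM, 2 GiB of weights, Llama-like shapes on 4-way TP.
print(max_n_blocks(16 << 30, 2 << 30, 0, num_layers=32, tensor_parallel_size=4,
                   block_size=128, head_dim=128, num_key_value_heads=8))  # 864
```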
```diff
@@ -833,22 +858,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cls,
         batch_size: int,
         query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        model_config: PretrainedConfig,
     ):
-        if use_inputs_embeds:
+        is_prefill: bool = query_length > 1
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
+        local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+
+        # 1. main input
+        if rbln_config.use_inputs_embeds:
             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
         else:
             main_input = ("input_ids", [batch_size, query_length], "int64")
 
+        # 2. cache_position
         input_info = [
             main_input,
             (
@@ -858,38 +885,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]
 
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-        if query_length > 1:
+        # 3. block_tables
+        if rbln_config.cache_impl in ["static", "hybrid"]:
+            max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
             input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
+                [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
             )
-        if use_attention_mask:
+        if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+            input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+
+        # 4. query_position
+        if is_prefill:
+            input_info.extend([("query_position", [], "int16")])
+
+        # 5. attention_mask & position_ids
+        if rbln_config.use_attention_mask:
             input_info.extend(
                 [
-                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                    ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
+                    if rbln_config.use_position_ids
+                    else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
                 ]
             )
-        if use_position_ids:
+        if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
+        # 6. past_key_values
+        global_kvcache_shape = [
+            rbln_config.kvcache_num_blocks,
+            num_key_value_heads,
+            rbln_config.kvcache_block_size,
+            head_dim,
+        ]
+        local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
-                    [
-                        kvcache_num_blocks,
-                        num_key_value_heads,
-                        kvcache_block_size,
-                        head_dim,
-                    ],
+                    local_kvcache_shape
+                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
+                    else global_kvcache_shape,
                     "float32",
                 )
                 for i in range(num_hidden_layers * 2)
```
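`get_input_info` now picks a per-layer KV-cache shape: paged global blocks for full-attention layers, a fixed-width window for sliding-window layers, with `i // 2` folding the interleaved key/value entries back onto layer indices. A sketch with made-up dimensions for a hypothetical 4-layer hybrid model:

```python
# Hypothetical hybrid model: 4 layers, layers 0 and 2 use sliding-window attention.
kvcache_num_blocks, kvcache_block_size = 64, 128
num_key_value_heads, head_dim = 8, 128
sliding_window, sliding_window_layers = 512, [0, 2]
local_kvcache_num_blocks = 2  # max(decoder_batch_sizes)
num_hidden_layers = 4

global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, sliding_window, head_dim]

input_info = [
    (
        f"past_key_values_{i}",
        # i // 2 maps the interleaved key/value entries back to a layer index
        local_kvcache_shape if (i // 2) in sliding_window_layers else global_kvcache_shape,
        "float32",
    )
    for i in range(num_hidden_layers * 2)
]
for name, shape, dtype in input_info[:4]:
    print(name, shape, dtype)
# past_key_values_0 [2, 8, 512, 128] float32   (layer 0 key: sliding window)
# past_key_values_1 [2, 8, 512, 128] float32   (layer 0 value)
# past_key_values_2 [64, 8, 128, 128] float32  (layer 1 key: global, paged)
# past_key_values_3 [64, 8, 128, 128] float32  (layer 1 value)
```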
```diff
@@ -898,12 +933,50 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return input_info
 
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # Update the sliding window configuration for the RBLN model.
+
+        # This method must be implemented by subclasses to handle their specific sliding window configurations,
+        # as Hugging Face models use different configuration keys to represent sliding window layers.
+
+        # Args:
+        #     model_config (PretrainedConfig): The model configuration from Hugging Face.
+        #     rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.
+
+        # Notes:
+        #     Required configuration settings:
+        #     - `cache_impl`: Must be one of:
+        #         - "static": All layers use global attention (no sliding window)
+        #         - "sliding_window": All layers use sliding window attention
+        #         - "hybrid": A mix of global and sliding window attention layers
+        #     - `sliding_window`: Width of the sliding window (required if cache_impl is "sliding_window" or "hybrid")
+        #     - `sliding_window_layers`: List of layer indices using sliding window attention (required if cache_impl is "hybrid")
+
+        # Example implementation for a 'sliding_window' model:
+        # ```python
+        # rbln_config.cache_impl = "sliding_window"
+        # rbln_config.sliding_window = model_config.sliding_window
+        # rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
+        # return rbln_config
+        # ```
+
+        # Returns:
+        #     RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+
+        raise NotImplementedError(
+            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
+            "See method docstring for required configuration details."
+        )
+
     @classmethod
     def _update_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
         rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
     ) -> RBLNDecoderOnlyModelForCausalLMConfig:
         if rbln_config.max_seq_len is None:
@@ -913,6 +986,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")
 
+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+            if rbln_config.sliding_window is not None:
+                validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
```
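`_update_rbln_config` calls the new hook whenever the Hugging Face config advertises a `sliding_window` that is not explicitly disabled, then validates the window against `prefill_chunk_size`. Following the example embedded in the comments above, a subclass for an all-sliding-window model might implement the hook roughly like this (sketched with `SimpleNamespace` stand-ins instead of the real config classes):

```python
from types import SimpleNamespace


def update_sliding_window_config(model_config, rbln_config):
    # Mirrors the example from the method comments: every layer slides.
    rbln_config.cache_impl = "sliding_window"
    rbln_config.sliding_window = model_config.sliding_window
    rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
    return rbln_config


model_config = SimpleNamespace(sliding_window=4096, num_hidden_layers=32)
rbln_config = SimpleNamespace(cache_impl=None, sliding_window=None, sliding_window_layers=None)
updated = update_sliding_window_config(model_config, rbln_config)
print(updated.cache_impl, updated.sliding_window, updated.sliding_window_layers[:4])
# sliding_window 4096 [0, 1, 2, 3]
```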
```diff
@@ -961,25 +1041,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     "This can cause a failure during model compilation."
                 )
             logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
 
         prefill_input_info = cls.get_input_info(
             batch_size=1,
             query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
+            rbln_config=rbln_config,
+            model_config=model_config,
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
@@ -989,16 +1056,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_input_info = cls.get_input_info(
                 batch_size=batch_size,
                 query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
+                rbln_config=rbln_config,
+                model_config=model_config,
             )
             dec_compile_configs.append(
                 RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
@@ -1122,12 +1181,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        """
-        Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
-        For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
-        A for-loop ensures synchronization with the HuggingFace generate API.
-        The decoder stage operates as usual, processing inputs in batch mode.
-        """
+        # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+        # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+        # A for-loop ensures synchronization with the HuggingFace generate API.
+        # The decoder stage operates as usual, processing inputs in batch mode.
+
         # Prefll
         if cache_position is None:
             logits = []
```