optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +1 -1
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +231 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +87 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +46 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -36,6 +36,7 @@ from .decoderonly_architecture import (
     DecoderOnlyWrapper,
     set_default_values,
     validate_attention_method,
+    validate_sliding_window_size,
 )
 
 
@@ -56,36 +57,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         block_tables: torch.Tensor,
         free_block_pool: Deque,
-        kvcache_block_size: int,
-        use_attention_mask: bool,
-        attn_impl: str,
-        use_position_ids: bool,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
-
-        # shared data structures between prefill and decode phase
-        self.use_attention_mask = use_attention_mask
+        self.rbln_config = rbln_config
 
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
-        self.use_position_ids = use_position_ids
 
-        self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
-        self.attn_impl = attn_impl
-
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
-            self.max_seq_len = kwargs.pop("max_seq_len")
-            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
             self.output_size = [1, 1, vocab_size]
             self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
 
     def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
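The constructor above now reads `prefill_chunk_size` from `rbln_config` when building its chunk-local causal mask. A minimal standalone sketch of the same mask construction (the chunk size of 4 is hypothetical):

```python
import torch

# Chunk-local causal mask as built in the constructor above:
# 1 on and below the diagonal (visible), 0 above it (masked).
chunk = 4  # stand-in for rbln_config.prefill_chunk_size
causal_mask = 1 - torch.triu(torch.ones(1, 1, chunk, chunk), diagonal=1)
print(causal_mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```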
@@ -131,31 +121,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             raise RuntimeError(NO_BLOCKS_ERROR)
 
-        … (23 lines of the previous get_block_tables implementation, not captured in the source view)
+        def get_global_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "sliding_window":
+                return None
+
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return replace_empty_block(self.block_tables[batch_idx])
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    update_block(b_idx, block_idx)
+
+                return replace_empty_block(self.block_tables)
 
-
+        def get_local_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "static":
+                return None
+            else:
+                return (
+                    torch.tensor([batch_idx], dtype=torch.int16)
+                    if self.phase == "prefill"
+                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
+                )
+
+        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
+
+    def is_external_block_tables(
+        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+    ):
+        if self.rbln_config.cache_impl == "static" and block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "hybrid":
+            if (block_tables is not None) != (local_block_tables is not None):
+                raise ValueError(
+                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                )
+        elif block_tables is None and local_block_tables is None:
+            return False
+        else:
+            return True
 
     def forward(
         self,
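The new `get_global_block_tables` helper keeps the paged-KV allocation scheme: on prefill it returns a sequence's previously used blocks to the shared pool, then allocates one block per `kvcache_block_size` positions. A simplified, self-contained sketch of that bookkeeping (all sizes are hypothetical, and `allocate_prefill` stands in for the prefill branch of the real method):

```python
from collections import deque

import torch

EMPTY = -1
block_size = 4            # stand-in for rbln_config.kvcache_block_size
max_blocks_per_seq = 4    # max_seq_len // block_size
block_tables = torch.full((2, max_blocks_per_seq), EMPTY, dtype=torch.int16)
free_block_pool = deque(range(8))  # indices of free KV-cache blocks

def allocate_prefill(batch_idx: int, start: int, end: int) -> torch.Tensor:
    # Return previously used blocks to the pool and clear this row.
    prev = block_tables[batch_idx][block_tables[batch_idx] != EMPTY].tolist()
    free_block_pool.extend(prev)
    block_tables[batch_idx].fill_(EMPTY)
    # Allocate one block per `block_size` cache positions in [start, end].
    for pos in range(start, end + 1, block_size):
        block_idx = pos // block_size
        if block_tables[batch_idx][block_idx] == EMPTY:
            block_tables[batch_idx][block_idx] = free_block_pool.popleft()
    return block_tables[batch_idx]

print(allocate_prefill(0, 0, 9))  # tensor([ 0,  1,  2, -1], dtype=torch.int16)
print(allocate_prefill(0, 0, 3))  # old blocks recycled: tensor([ 3, -1, -1, -1], ...)
```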
@@ -180,11 +203,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds
 
-        … (3 lines not captured in the source view)
-        else:
-            is_external_block_tables = True
+        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+        if not is_external_block_tables:
+            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
 
         if self.phase == "decode":
             return self.decode_forward(
@@ -204,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             attention_mask,
             batch_idx,
             block_tables,
+            is_external_block_tables=is_external_block_tables,
             position_embed=position_embed,
             token_type_ids=token_type_ids,
             local_block_tables=local_block_tables,
@@ -229,7 +251,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -245,7 +267,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
             attention_mask = self.dec_attn_mask
 
-        if self.batch_size < block_tables.shape[0]:
+        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
             block_tables = block_tables[: self.batch_size]
 
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -255,9 +277,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             block_tables,
+            local_block_tables,
             position_embed,
-            attention_mask if self.use_attention_mask else None,
-            position_ids if self.use_position_ids else None,
+            attention_mask if self.rbln_config.use_attention_mask else None,
+            position_ids if self.rbln_config.use_position_ids else None,
         )
 
         return RBLNDecoderOnlyOutput(logits=logits)
@@ -268,7 +291,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
@@ -283,15 +305,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Initialize attention mask for chunked processing
         chunked_attention_mask = (
-            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-            if self.use_attention_mask
+            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask
             else None
         )
 
@@ -305,8 +327,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         ]
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        … (2 lines not captured in the source view)
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
         # inputs_embeds
         if inputs.dim() == 3:
             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
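The new padding logic rounds the prompt up to a multiple of `prefill_chunk_size` so the final chunk is full-width. A small sketch of the arithmetic (sizes hypothetical):

```python
import torch

# Pad the sequence so its length becomes a multiple of the chunk size,
# mirroring the padding_size computation above.
prefill_chunk_size = 128
query_length = 300
padding_size = 0
if query_length % prefill_chunk_size != 0:
    padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size

input_ids = torch.zeros(1, query_length, dtype=torch.int64)
padded = torch.nn.functional.pad(input_ids, (0, padding_size))
print(padding_size, padded.shape)  # 84 torch.Size([1, 384])
```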
@@ -351,10 +374,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = …
+        is_external_block_tables: bool = False,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -375,39 +398,47 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size]
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
             if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
+                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
 
-            if self.use_attention_mask and not self.use_position_ids:
+            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
-                if step >= self.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if step >= self.rbln_config.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
 
             # Define query position
-            … (1 line not captured in the source view)
+            if step + self.rbln_config.prefill_chunk_size >= query_length:
+                query_position = torch.tensor(
+                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
+                )
+            else:
+                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
 
             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                local_block_tables,
                 position_embed_chunk if position_embed is not None else None,
                 query_position,
-                chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if self.use_position_ids else None,
+                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                position_ids_chunk if self.rbln_config.use_position_ids else None,
                 out=out_buffers,
             )
 
         # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.use_attention_mask:
+        if not is_external_block_tables and self.rbln_config.use_attention_mask:
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
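Within the chunk loop, `query_position` marks the offset of the last real token inside the current chunk: full chunks point at their final slot, while the last (padded) chunk points at `(query_length - 1) % prefill_chunk_size`. A standalone sketch (sizes hypothetical):

```python
import torch

prefill_chunk_size = 128
query_length = 300  # padded to 384, i.e. 3 chunks

for step in range(0, query_length, prefill_chunk_size):
    if step + prefill_chunk_size >= query_length:
        # Last chunk: point at the final real (unpadded) token.
        query_position = torch.tensor((query_length - 1) % prefill_chunk_size, dtype=torch.int16)
    else:
        # Full chunk: point at its last position.
        query_position = torch.tensor(prefill_chunk_size - 1, dtype=torch.int16)
    print(step, int(query_position))
# 0 127
# 128 127
# 256 43
```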
@@ -477,13 +508,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            rbln_config=self.rbln_config,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
         )
 
         self.decoders = {}
@@ -497,10 +523,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            rbln_config=self.rbln_config,
         )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -514,10 +537,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         subfolder: str,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
         if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
@@ -625,6 +646,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "use_attention_mask": rbln_config.use_attention_mask,
             "use_position_ids": rbln_config.use_position_ids,
             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+            "cache_impl": rbln_config.cache_impl,
+            "sliding_window": rbln_config.sliding_window,
+            "sliding_window_layers": rbln_config.sliding_window_layers,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -709,7 +733,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
         ):
             alloc_memory_by_key[key] += sum(memory_per_node)
-        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
+        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
 
         # Get the maximum number of blocks that can be allocated
@@ -746,35 +771,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         buffer: Optional[int] = None,
         num_runtimes: int = 2,
     ) -> int:
-        … (27 docstring lines not captured in the source view)
-        x = floor(2**21 / b * floor((k - 1) / 2**21))
-        """
+        # We are finding max_n_blocks(x) that satisfies the following equation:
+
+        # available_dram - kernel_size - buffer
+        # - num_layers * 2 * tensor_parallel_size
+        # * align_2MB(
+        #     x
+        #     * block_size
+        #     * align_64(head_dim)
+        #     * math.ceil(num_key_value_heads / tensor_parallel_size)
+        #     * 2
+        # ) > 0
+
+        # This inequality can be rewritten as follows:
+
+        # a - c * align_2MB(b * x) > 0
+        # where
+        #     a = available_dram - kernel_size - buffer
+        #     b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        #     c = num_layers * 2 * tensor_parallel_size
+
+        # We can rewrite the inequality as follows:
+        # k > align_2MB(b*x)
+        # where
+        #     k = a / c
+
+        # After that, we can derive the following equation:
+        # x = floor(2**21 / b * floor((k - 1) / 2**21))
 
         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
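The comment block above derives a closed-form bound on the number of allocatable KV-cache blocks. A sketch of that formula as a function (the numeric arguments in the call are made up for illustration; the 2 MB and 64 B constants are the alignments named in the derivation):

```python
import math

def align(x: int, nbytes: int) -> int:
    # Round x up to the nearest multiple of nbytes.
    return int(math.ceil(x / nbytes) * nbytes)

def max_n_blocks(available_dram, kernel_size, buffer, num_layers,
                 tensor_parallel_size, block_size, head_dim, num_key_value_heads):
    a = available_dram - kernel_size - buffer
    b = block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = a / c
    # x = floor(2**21 / b * floor((k - 1) / 2**21))
    return math.floor(2**21 / b * math.floor((k - 1) / 2**21))

# Hypothetical device: 16 GiB DRAM, 2 GiB weights, 1 MiB buffer,
# 32 layers, TP=4, block_size=128, head_dim=128, 8 KV heads.
print(max_n_blocks(16 * 2**30, 2 * 2**30, 2**20, 32, 4, 128, 128, 8))
```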
@@ -833,22 +856,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cls,
         batch_size: int,
         query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        model_config: PretrainedConfig,
     ):
-        if use_inputs_embeds:
+        is_prefill: bool = query_length > 1
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
+        local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+
+        # 1. main input
+        if rbln_config.use_inputs_embeds:
             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
         else:
             main_input = ("input_ids", [batch_size, query_length], "int64")
 
+        # 2. cache_position
         input_info = [
             main_input,
             (
@@ -858,38 +883,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]
 
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-        if query_length > 1:
+        # 3. block_tables
+        if rbln_config.cache_impl in ["static", "hybrid"]:
+            max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
             input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
+                [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
             )
-        if use_attention_mask:
+        if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+            input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+
+        # 4. query_position
+        if is_prefill:
+            input_info.extend([("query_position", [], "int16")])
+
+        # 5. attention_mask & position_ids
+        if rbln_config.use_attention_mask:
             input_info.extend(
                 [
-                    ("attention_mask", [batch_size, …
+                    ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
+                    if rbln_config.use_position_ids
+                    else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
                 ]
             )
-        if use_position_ids:
+        if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
+        # 6. past_key_values
+        global_kvcache_shape = [
+            rbln_config.kvcache_num_blocks,
+            num_key_value_heads,
+            rbln_config.kvcache_block_size,
+            head_dim,
+        ]
+        local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
-                    [
-                        kvcache_num_blocks,
-                        num_key_value_heads,
-                        kvcache_block_size,
-                        head_dim,
-                    ],
+                    local_kvcache_shape
+                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
+                    else global_kvcache_shape,
                     "float32",
                 )
                 for i in range(num_hidden_layers * 2)
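The per-layer shape selection above gives sliding-window layers a window-sized cache and every other layer the global paged cache, with two `past_key_values` entries (key and value) per layer. A toy sketch of the same selection (all sizes and the layer layout are hypothetical):

```python
# Hypothetical hybrid layout: layers 1 and 3 use sliding-window attention.
num_hidden_layers = 4
sliding_window_layers = [1, 3]
global_kvcache_shape = [256, 8, 64, 128]   # [num_blocks, kv_heads, block_size, head_dim]
local_kvcache_shape = [2, 8, 512, 128]     # [max decoder batch, kv_heads, window, head_dim]

input_info = [
    (
        f"past_key_values_{i}",
        local_kvcache_shape if (i // 2) in sliding_window_layers else global_kvcache_shape,
        "float32",
    )
    for i in range(num_hidden_layers * 2)  # one key + one value entry per layer
]
for name, shape, dtype in input_info:
    print(name, shape, dtype)
```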
@@ -898,6 +931,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return input_info
 
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # Update the sliding window configuration for the RBLN model.
+
+        # This method must be implemented by subclasses to handle their specific sliding window configurations,
+        # as Hugging Face models use different configuration keys to represent sliding window layers.
+
+        # Args:
+        #     model_config (PretrainedConfig): The model configuration from Hugging Face.
+        #     rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.
+
+        # Notes:
+        #     Required configuration settings:
+        #     - `cache_impl`: Must be one of:
+        #         - "static": All layers use global attention (no sliding window)
+        #         - "sliding_window": All layers use sliding window attention
+        #         - "hybrid": A mix of global and sliding window attention layers
+        #     - `sliding_window`: Width of the sliding window (required if cache_impl is "sliding_window" or "hybrid")
+        #     - `sliding_window_layers`: List of layer indices using sliding window attention (required if cache_impl is "hybrid")
+
+        #     Example implementation for a 'sliding_window' model:
+        #     ```python
+        #     rbln_config.cache_impl = "sliding_window"
+        #     rbln_config.sliding_window = model_config.sliding_window
+        #     rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
+        #     return rbln_config
+        #     ```
+
+        # Returns:
+        #     RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+
+        raise NotImplementedError(
+            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
+            "See method docstring for required configuration details."
+        )
+
     @classmethod
     def _update_rbln_config(
         cls,
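Since `_update_sliding_window_config` deliberately raises `NotImplementedError`, each model class is expected to override it. A hypothetical subclass sketch that follows the example embedded in the comments above (`RBLNMyModelForCausalLM` is illustrative, not a class in the package; the import path matches the module shown in this diff):

```python
from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
    RBLNDecoderOnlyModelForCausalLM,
)

# Hypothetical override for a model whose layers all use sliding-window
# attention, mirroring the docstring's own example.
class RBLNMyModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    @classmethod
    def _update_sliding_window_config(cls, model_config, rbln_config):
        rbln_config.cache_impl = "sliding_window"
        rbln_config.sliding_window = model_config.sliding_window
        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
        return rbln_config
```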
@@ -913,6 +984,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")
 
+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+        if rbln_config.sliding_window is not None:
+            validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -961,25 +1039,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "This can cause a failure during model compilation."
             )
         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
 
         prefill_input_info = cls.get_input_info(
             batch_size=1,
             query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
+            rbln_config=rbln_config,
+            model_config=model_config,
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
@@ -989,16 +1054,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         dec_input_info = cls.get_input_info(
             batch_size=batch_size,
             query_length=1,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
+            rbln_config=rbln_config,
+            model_config=model_config,
         )
         dec_compile_configs.append(
             RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
@@ -1122,12 +1179,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        """
-        Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
-        For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
-        A for-loop ensures synchronization with the HuggingFace generate API.
-        The decoder stage operates as usual, processing inputs in batch mode.
-        """
+        # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+        # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+        # A for-loop ensures synchronization with the HuggingFace generate API.
+        # The decoder stage operates as usual, processing inputs in batch mode.
+
         # Prefll
         if cache_position is None:
             logits = []
|