optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +53 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
@@ -36,6 +36,7 @@ from .decoderonly_architecture import (
     DecoderOnlyWrapper,
     set_default_values,
    validate_attention_method,
+    validate_sliding_window_size,
 )
 
 
@@ -56,36 +57,25 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         block_tables: torch.Tensor,
         free_block_pool: Deque,
-        kvcache_block_size: int,
-        use_attention_mask: bool,
-        attn_impl: str,
-        use_position_ids: bool,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
-
-        # shared data structures between prefill and decode phase
-        self.use_attention_mask = use_attention_mask
+        self.rbln_config = rbln_config
 
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
-        self.use_position_ids = use_position_ids
 
-        self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
-        self.attn_impl = attn_impl
-
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
-            self.max_seq_len = kwargs.pop("max_seq_len")
-            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
             self.output_size = [1, 1, vocab_size]
             self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
 
     def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
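The runtime now reads compile-time settings from a single `rbln_config` instead of loose constructor arguments. The `causal_mask` built above is the usual lower-triangular chunk mask; a minimal standalone sketch, assuming a chunk size of 4 purely for illustration:

```python
import torch

prefill_chunk_size = 4  # assumed value; the runtime uses rbln_config.prefill_chunk_size
causal_mask = 1 - torch.triu(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1)
# causal_mask[0, 0] is lower-triangular:
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```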
@@ -131,31 +121,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 else:
                     raise RuntimeError(NO_BLOCKS_ERROR)
 
-        if self.phase == "prefill":
-            # Track previously used blocks and return them to the free_block_pool and
-            # reset the current batch's block table to empty blocks
-            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-            self.free_block_pool.extend(prev_blocks)
-            self.block_tables[batch_idx].fill_(self.empty_block)
-            # Get the start (s) and end (e) positions from cache_position and
-            # iterate over the cache positions to allocate necessary blocks
-            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-            for position in range(s, e + 1, self.kvcache_block_size):
-                block_idx = position // self.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                update_block(batch_idx, block_idx)
-            return replace_empty_block(self.block_tables[batch_idx])
-        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-        else:
-            for b_idx in range(self.batch_size):
-                position = cache_position[b_idx][0].item()
-                block_idx = position // self.kvcache_block_size
-                update_block(b_idx, block_idx)
-
-            return replace_empty_block(self.block_tables)
+        def get_global_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "sliding_window":
+                return None
+
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return replace_empty_block(self.block_tables[batch_idx])
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.rbln_config.kvcache_block_size
+                    update_block(b_idx, block_idx)
+
+                return replace_empty_block(self.block_tables)
 
-
+        def get_local_block_tables(batch_idx: int):
+            if self.rbln_config.cache_impl == "static":
+                return None
+            else:
+                return (
+                    torch.tensor([batch_idx], dtype=torch.int16)
+                    if self.phase == "prefill"
+                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
+                )
+
+        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
+
+    def is_external_block_tables(
+        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
+    ):
+        if self.rbln_config.cache_impl == "static" and block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
+            return False
+        elif self.rbln_config.cache_impl == "hybrid":
+            if (block_tables is not None) != (local_block_tables is not None):
+                raise ValueError(
+                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
+                )
+            elif block_tables is None and local_block_tables is None:
+                return False
+        else:
+            return True
 
     def forward(
         self,
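`get_block_tables` now returns a pair (global, local): the paged global table serves "static" caches, the per-batch local table serves "sliding_window" caches, and "hybrid" needs both. A minimal sketch of that dispatch, using illustrative names that are not part of the library API:

```python
from typing import Optional, Tuple
import torch

def select_tables(
    cache_impl: str, global_tb: torch.Tensor, local_tb: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mirrors the logic above: each cache implementation consumes only the
    # table(s) it needs; the unused one is reported as None.
    if cache_impl == "static":
        return global_tb, None
    if cache_impl == "sliding_window":
        return None, local_tb
    return global_tb, local_tb  # "hybrid" uses both
```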
@@ -180,11 +203,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds
 
-        if block_tables is None:
-            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-            is_external_block_tables = False
-        else:
-            is_external_block_tables = True
+        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
+        if not is_external_block_tables:
+            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
 
         if self.phase == "decode":
             return self.decode_forward(
@@ -204,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask,
                 batch_idx,
                 block_tables,
+                is_external_block_tables=is_external_block_tables,
                 position_embed=position_embed,
                 token_type_ids=token_type_ids,
                 local_block_tables=local_block_tables,
@@ -229,7 +251,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -245,7 +267,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
             attention_mask = self.dec_attn_mask
 
-        if self.batch_size < block_tables.shape[0]:
+        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
             block_tables = block_tables[: self.batch_size]
 
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -255,9 +277,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             block_tables,
+            local_block_tables,
             position_embed,
-            attention_mask if self.use_attention_mask else None,
-            position_ids if self.use_position_ids else None,
+            attention_mask if self.rbln_config.use_attention_mask else None,
+            position_ids if self.rbln_config.use_position_ids else None,
         )
 
         return RBLNDecoderOnlyOutput(logits=logits)
@@ -268,7 +291,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
@@ -283,15 +305,15 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Initialize attention mask for chunked processing
         chunked_attention_mask = (
-            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-            if self.use_attention_mask
+            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask
             else None
         )
 
@@ -305,8 +327,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         ]
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
         # inputs_embeds
         if inputs.dim() == 3:
             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
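The new `(prefill_chunk_size - query_length) % prefill_chunk_size` form pads to the next chunk boundary even when the prompt is longer than one chunk, because Python's modulo of a negative number is non-negative. A quick arithmetic check, assuming a chunk size of 128:

```python
prefill_chunk_size = 128  # assumed for illustration

for query_length in (128, 100, 300):
    padding_size = 0
    if query_length % prefill_chunk_size != 0:
        padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
    # Padded length always lands exactly on a chunk boundary.
    assert (query_length + padding_size) % prefill_chunk_size == 0
    print(query_length, "->", padding_size)  # 128 -> 0, 100 -> 28, 300 -> 84
```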
@@ -351,10 +374,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
+        is_external_block_tables: bool = False,
         position_embed: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -375,39 +398,47 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
 
         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
             if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
+                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
 
-            if self.use_attention_mask and not self.use_position_ids:
+            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
-                if step >= self.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if step >= self.rbln_config.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
 
             # Define query position
-            query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+            if step + self.rbln_config.prefill_chunk_size >= query_length:
+                query_position = torch.tensor(
+                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
+                )
+            else:
+                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
 
             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                local_block_tables,
                 position_embed_chunk if position_embed is not None else None,
                 query_position,
-                chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if self.use_position_ids else None,
+                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
+                position_ids_chunk if self.rbln_config.use_position_ids else None,
                 out=out_buffers,
             )
 
             # Update decoder attention mask with processed KV-cache length from prefill phase
-            if not is_external_block_tables and self.use_attention_mask:
+            if not is_external_block_tables and self.rbln_config.use_attention_mask:
                 self.dec_attn_mask[batch_idx].fill_(0)
                 self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
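`query_position` marks the index of the last valid token within the current chunk, so the remainder arithmetic only applies to the final (possibly padded) chunk. A small trace, assuming a chunk size of 128 and a 300-token prompt:

```python
prefill_chunk_size = 128  # assumed
query_length = 300        # assumed prompt length (padded to 384 for compilation)

for step in range(0, query_length, prefill_chunk_size):
    if step + prefill_chunk_size >= query_length:
        query_position = (query_length - 1) % prefill_chunk_size  # last chunk
    else:
        query_position = prefill_chunk_size - 1  # a full intermediate chunk
    print(step, query_position)  # 0 127, 128 127, 256 43
```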
@@ -477,13 +508,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            rbln_config=self.rbln_config,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
         )
 
         self.decoders = {}
@@ -497,10 +523,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-                kvcache_block_size=self.rbln_config.kvcache_block_size,
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -514,10 +537,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         subfolder: str,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
         if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
@@ -625,6 +646,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "use_attention_mask": rbln_config.use_attention_mask,
             "use_position_ids": rbln_config.use_position_ids,
             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+            "cache_impl": rbln_config.cache_impl,
+            "sliding_window": rbln_config.sliding_window,
+            "sliding_window_layers": rbln_config.sliding_window_layers,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -709,7 +733,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
         ):
             alloc_memory_by_key[key] += sum(memory_per_node)
-
         alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
         alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
@@ -748,35 +771,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         buffer: Optional[int] = None,
         num_runtimes: int = 2,
     ) -> int:
-        """
-        We are finding max_n_blocks(x) that satisfies the following equation:
-
-        available_dram - kernel_size - buffer
-        - num_layers * 2 * tensor_parallel_size
-          * align_2MB(
-              x
-              * block_size
-              * align_64(head_dim)
-              * math.ceil(num_key_value_heads / tensor_parallel_size)
-              * 2
-          ) > 0
-
-        This inequality can be rewritten as follows:
-
-          a - c * align_2MB(b * x) > 0
-        where
-          a = available_dram - kernel_size - buffer
-          b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-          c = num_layers * 2 * tensor_parallel_size
-
-        We can rewrite the inequality as follows:
-          k > align_2MB(b*x)
-        where
-          k = a / c
-
-        After that, we can derive the following equation:
-        x = floor(2**21 / b * floor((k - 1) / 2**21))
-        """
+        # We are finding max_n_blocks(x) that satisfies the following equation:
+
+        # available_dram - kernel_size - buffer
+        # - num_layers * 2 * tensor_parallel_size
+        #   * align_2MB(
+        #       x
+        #       * block_size
+        #       * align_64(head_dim)
+        #       * math.ceil(num_key_value_heads / tensor_parallel_size)
+        #       * 2
+        #   ) > 0
+
+        # This inequality can be rewritten as follows:
+
+        #   a - c * align_2MB(b * x) > 0
+        # where
+        #   a = available_dram - kernel_size - buffer
+        #   b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        #   c = num_layers * 2 * tensor_parallel_size
+
+        # We can rewrite the inequality as follows:
+        #   k > align_2MB(b*x)
+        # where
+        #   k = a / c
+
+        # After that, we can derive the following equation:
+        #   x = floor(2**21 / b * floor((k - 1) / 2**21))
 
         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
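The closed form `x = floor(2**21 / b * floor((k - 1) / 2**21))` comes from inverting the 2 MiB (2**21-byte) alignment. A standalone numeric check of the derivation, with every size assumed purely for illustration (not real device numbers):

```python
import math

available_dram = 16 * 2**30          # assumed: 16 GiB
kernel_size, buffer = 2 * 2**30, 2**30
num_layers, tensor_parallel_size = 32, 4
block_size, head_dim, num_key_value_heads = 128, 128, 8

def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)

a = available_dram - kernel_size - buffer
b = block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
c = num_layers * 2 * tensor_parallel_size
k = a / c
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))  # 800 with these numbers

# The derived block count still satisfies the original DRAM inequality:
assert a - c * align(b * max_n_blocks, 2**21) > 0
```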
@@ -835,22 +856,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cls,
         batch_size: int,
         query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        model_config: PretrainedConfig,
     ):
-        if use_inputs_embeds:
+        is_prefill: bool = query_length > 1
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
+        local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+
+        # 1. main input
+        if rbln_config.use_inputs_embeds:
             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
         else:
             main_input = ("input_ids", [batch_size, query_length], "int64")
 
+        # 2. cache_position
         input_info = [
             main_input,
             (
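The `getattr` chains above normalize the two config vocabularies HuggingFace models use (GPT-2-style `n_head`/`n_layer`/`n_embd` vs. Llama-style `num_attention_heads`/`num_hidden_layers`/`hidden_size`). A quick standalone check with stand-in config objects:

```python
from types import SimpleNamespace

def normalize(cfg):
    heads = getattr(cfg, "n_head", None) or getattr(cfg, "num_attention_heads")
    layers = getattr(cfg, "n_layer", None) or getattr(cfg, "num_hidden_layers")
    hidden = getattr(cfg, "n_embd", None) or getattr(cfg, "hidden_size")
    return heads, layers, hidden

gpt2_like = SimpleNamespace(n_head=12, n_layer=12, n_embd=768)
llama_like = SimpleNamespace(num_attention_heads=32, num_hidden_layers=32, hidden_size=4096)
assert normalize(gpt2_like) == (12, 12, 768)
assert normalize(llama_like) == (32, 32, 4096)
```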
@@ -860,38 +883,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]
 
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-        if query_length > 1:
+        # 3. block_tables
+        if rbln_config.cache_impl in ["static", "hybrid"]:
+            max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
             input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
+                [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
             )
-        if use_attention_mask:
+        if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+            input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+
+        # 4. query_position
+        if is_prefill:
+            input_info.extend([("query_position", [], "int16")])
+
+        # 5. attention_mask & position_ids
+        if rbln_config.use_attention_mask:
             input_info.extend(
                 [
-                    ("attention_mask", [batch_size, max_seq_len] if use_position_ids else [batch_size, 1, query_length, max_seq_len], "float32"),
+                    ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
+                    if rbln_config.use_position_ids
+                    else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
                 ]
             )
-        if use_position_ids:
+        if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
+        # 6. past_key_values
+        global_kvcache_shape = [
+            rbln_config.kvcache_num_blocks,
+            num_key_value_heads,
+            rbln_config.kvcache_block_size,
+            head_dim,
+        ]
+        local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
-                    [
-                        kvcache_num_blocks,
-                        num_key_value_heads,
-                        kvcache_block_size,
-                        head_dim,
-                    ],
+                    local_kvcache_shape
+                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
+                    else global_kvcache_shape,
                     "float32",
                 )
                 for i in range(num_hidden_layers * 2)
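The comprehension indexes layers as `i // 2` because keys and values are interleaved (two cache tensors per layer), so a hybrid model picks the local shape for both tensors of every sliding-window layer. A toy illustration with assumed sizes:

```python
# Assumed toy configuration: 4 layers, layers 0 and 2 use sliding-window attention.
sliding_window_layers = [0, 2]
global_shape = [1024, 8, 128, 64]  # [num_blocks, kv_heads, block_size, head_dim]
local_shape = [2, 8, 512, 64]      # [max_decoder_batch, kv_heads, sliding_window, head_dim]

shapes = [
    local_shape if (i // 2) in sliding_window_layers else global_shape
    for i in range(4 * 2)  # key and value tensors, interleaved per layer
]
# shapes[0:2] -> local (layer 0), shapes[2:4] -> global (layer 1), and so on.
```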
@@ -900,6 +931,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return input_info
 
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        # Update the sliding window configuration for the RBLN model.
+
+        # This method must be implemented by subclasses to handle their specific sliding window configurations,
+        # as Hugging Face models use different configuration keys to represent sliding window layers.
+
+        # Args:
+        #     model_config (PretrainedConfig): The model configuration from Hugging Face.
+        #     rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.
+
+        # Notes:
+        #     Required configuration settings:
+        #     - `cache_impl`: Must be one of:
+        #       - "static": All layers use global attention (no sliding window)
+        #       - "sliding_window": All layers use sliding window attention
+        #       - "hybrid": A mix of global and sliding window attention layers
+        #     - `sliding_window`: Width of the sliding window (required if cache_impl is "sliding_window" or "hybrid")
+        #     - `sliding_window_layers`: List of layer indices using sliding window attention (required if cache_impl is "hybrid")
+
+        # Example implementation for a 'sliding_window' model:
+        # ```python
+        # rbln_config.cache_impl = "sliding_window"
+        # rbln_config.sliding_window = model_config.sliding_window
+        # rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
+        # return rbln_config
+        # ```
+
+        # Returns:
+        #     RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+
+        raise NotImplementedError(
+            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
+            "See method docstring for required configuration details."
+        )
+
     @classmethod
     def _update_rbln_config(
         cls,
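For a hybrid model (some layers global, some sliding window), a subclass override could look like the sketch below. The attribute name `sliding_window_pattern` is a hypothetical stand-in; real models expose this information under model-specific config keys, which is exactly why the base class refuses to guess:

```python
class RBLNMyHybridModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    @classmethod
    def _update_sliding_window_config(cls, model_config, rbln_config):
        # Hypothetical rule: every N-th layer is global, the rest slide.
        pattern = getattr(model_config, "sliding_window_pattern", 4)
        rbln_config.cache_impl = "hybrid"
        rbln_config.sliding_window = model_config.sliding_window
        rbln_config.sliding_window_layers = [
            i for i in range(model_config.num_hidden_layers) if (i + 1) % pattern != 0
        ]
        return rbln_config
```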
@@ -915,6 +984,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")
 
+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+            if rbln_config.sliding_window is not None:
+                validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -963,25 +1039,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "This can cause a failure during model compilation."
             )
         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
 
         prefill_input_info = cls.get_input_info(
             batch_size=1,
             query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
+            rbln_config=rbln_config,
+            model_config=model_config,
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
@@ -991,16 +1054,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             dec_input_info = cls.get_input_info(
                 batch_size=batch_size,
                 query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
+                rbln_config=rbln_config,
+                model_config=model_config,
             )
             dec_compile_configs.append(
                 RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
@@ -1124,12 +1179,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        """
-        Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
-        For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
-        A for-loop ensures synchronization with the HuggingFace generate API.
-        The decoder stage operates as usual, processing inputs in batch mode.
-        """
+        # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+        # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+        # A for-loop ensures synchronization with the HuggingFace generate API.
+        # The decoder stage operates as usual, processing inputs in batch mode.
+
         # Prefll
         if cache_position is None:
             logits = []
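The continuous-batching contract described in those comments (prefill one sequence at a time, then batched decode) can be pictured with a small driver loop. Everything below is schematic: `prefill_fn` and `decode_fn` stand in for the compiled prefill and decode runtimes and are not names from the optimum-rbln API, and the tensor shapes are assumed:

```python
import torch

def generate_step(prefill_fn, decode_fn, prompts):
    # prefill_fn(input_ids, cache_position, batch_idx) -> logits for one sequence
    # decode_fn(input_ids, cache_position) -> logits for the whole batch
    last_logits = []
    for batch_idx, input_ids in enumerate(prompts):
        # Prefill one sequence at a time so its KV cache lands in slot batch_idx.
        cache_position = torch.arange(input_ids.shape[1]).unsqueeze(0)
        last_logits.append(prefill_fn(input_ids, cache_position, batch_idx))

    # Decode phase: one new token per sequence, processed as a single batch.
    next_tokens = torch.cat([l[:, -1:].argmax(-1) for l in last_logits])
    positions = torch.tensor([[p.shape[1]] for p in prompts])
    return decode_fn(next_tokens.view(len(prompts), 1), positions)
```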