optimum-rbln 0.7.4a9__py3-none-any.whl → 0.7.5a1__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +11 -7
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
- optimum/rbln/modeling.py +7 -5
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +22 -3
- optimum/rbln/transformers/models/__init__.py +23 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +42 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +251 -135
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
- optimum/rbln/transformers/models/siglip/__init__.py +20 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/submodule.py +13 -1
- {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +39 -28
- {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -30,7 +30,7 @@ from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...utils.rbln_quantization import
+from ...utils.rbln_quantization import prepare_model_for_quantization
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
@@ -59,6 +59,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         kvcache_block_size: int,
         use_attention_mask: bool,
         attn_impl: str,
+        use_position_ids: bool,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
@@ -72,6 +73,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         self.dec_attn_mask = dec_attn_mask
         self.block_tables = block_tables
         self.free_block_pool = free_block_pool
+        self.use_position_ids = use_position_ids

         self.kvcache_block_size = kvcache_block_size
         self.empty_block = -1
@@ -164,6 +166,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -189,10 +192,16 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 is_external_block_tables,
                 attention_mask=attention_mask,
                 position_embed=position_embed,
+                position_ids=position_ids,
             )
         else:
             return self.prefill_forward(
-                inputs,
+                inputs,
+                cache_position,
+                attention_mask,
+                batch_idx,
+                block_tables,
+                position_embed=position_embed,
             )

     def decode_forward(
@@ -203,6 +212,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -229,32 +239,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

             attention_mask = self.dec_attn_mask

+        if self.batch_size < block_tables.shape[0]:
+            block_tables = block_tables[: self.batch_size]
+
+        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
+            attention_mask = attention_mask[: self.batch_size]
+
         logits = super().forward(
             inputs,
             cache_position,
-            attention_mask if self.use_attention_mask else None,
             block_tables,
             position_embed,
+            attention_mask if self.use_attention_mask else None,
+            position_ids if self.use_position_ids else None,
         )

-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits)

-    def
+    def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
-        cache_position: torch.Tensor
+        cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
-    )
+    ):
         """
-
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        Prepare inputs for prefill phase.
         """
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -270,8 +281,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )

         # Initialize attention mask for chunked processing
-
-
+        chunked_attention_mask = (
+            torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+            if self.use_attention_mask
+            else None
+        )

         # Buffer for storing output logits
         out_buffers = [
@@ -282,36 +296,80 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
         ]

-        #
-
-
-
-
-
-
-
-
-        else:
-            inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+        if query_length % self.prefill_chunk_size != 0:
+            padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+            # inputs_embeds
+            if inputs.dim() == 3:
+                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            # inputs_ids
+            else:
+                inputs = torch.nn.functional.pad(inputs, (0, padding_size))

-
-
-
-
-
-
-
-
-
-
-
+            cache_position = torch.cat(
+                [
+                    cache_position,
+                    torch.arange(
+                        query_length,
+                        query_length + padding_size,
+                        dtype=torch.int32,
+                    ).unsqueeze(0),
+                ],
+                dim=-1,
+            )
+
+            if position_embed is not None:
+                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
+        # Overwrite position_ids and padded_cache_lengths
+        position_ids = None
+        padded_cache_lengths = 0
+
+        return (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        )

-
-
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+        ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)

+        # Process input in chunks of size `prefill_chunk_size`
+        for step in range(0, query_length, self.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            position_ids_chunk = (
+                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+            )
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

@@ -328,9 +386,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             logits = super().forward(
                 input_chunk,
                 cache_pos_chunk,
-                chunked_attention_mask if self.use_attention_mask else None,
-                query_position,
                 block_tables,
+                query_position,
+                chunked_attention_mask if self.use_attention_mask else None,
+                position_ids_chunk if position_ids is not None else None,
                 position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )
@@ -340,13 +399,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return logits
+        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
+    padded_cache_lengths: int = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
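The padding and chunking arithmetic that `_prepare_prefill_inputs` and `prefill_forward` apply above can be illustrated in isolation. The sketch below is a standalone approximation with assumed values (chunk size 128, a 300-token prompt) and plain tensors, not the RBLN runtime:

```python
import torch

# Assumed illustrative values; the real ones come from the RBLN config.
prefill_chunk_size = 128
query_length = 300  # prompt length

# Pad the prompt so the last chunk is a full prefill_chunk_size (300 -> 384).
padding_size = 0
if query_length % prefill_chunk_size != 0:
    padding_size = prefill_chunk_size - query_length % prefill_chunk_size  # 84

input_ids = torch.ones(1, query_length, dtype=torch.int64)
input_ids = torch.nn.functional.pad(input_ids, (0, padding_size))  # shape (1, 384)

# Extend cache positions to cover the padded tail as well.
cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
cache_position = torch.cat(
    [cache_position, torch.arange(query_length, query_length + padding_size, dtype=torch.int32).unsqueeze(0)],
    dim=-1,
)

# The padded sequence is then consumed one full chunk at a time: 0..128, 128..256, 256..384.
for step in range(0, query_length, prefill_chunk_size):
    input_chunk = input_ids[:, step : step + prefill_chunk_size]
    cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size]
```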
@@ -416,20 +476,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=self.rbln_config.max_seq_len,
             use_attention_mask=self.rbln_config.use_attention_mask,
             attn_impl=self.rbln_config.attn_impl,
+            use_position_ids=self.rbln_config.use_position_ids,
         )
-        self.
-
-
-
-
-
-
-
-
-
-
-
+        self.decoders = {}
+        for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+            self.decoders[batch_size] = RBLNRuntimeModel(
+                runtime=self.model[i + 1],
+                main_input_name=main_input_name,
+                embed_tokens=self.embed_tokens,
+                phase="decode",
+                batch_size=batch_size,
+                dec_attn_mask=dec_attn_mask,
+                block_tables=block_tables,
+                free_block_pool=free_block_pool,
+                kvcache_block_size=self.rbln_config.kvcache_block_size,
+                use_attention_mask=self.rbln_config.use_attention_mask,
+                attn_impl=self.rbln_config.attn_impl,
+                use_position_ids=self.rbln_config.use_position_ids,
+            )
+
+        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+        self.decoder = self.decoders[self.rbln_config.batch_size]

     @classmethod
     def save_torch_artifacts(
@@ -471,8 +538,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import prepare_model_for_quantization
-
         kwargs = cls.update_kwargs(kwargs)

         if config is None:
@@ -489,8 +554,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)

-
-
+        model = prepare_model_for_quantization(
+            model,
+            model_id,
+            kwargs.get("num_hidden_layers"),
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
         return model

     def __getattr__(self, __name: str) -> Any:
@@ -517,11 +590,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
     ) -> "PreTrainedModel":
-        if
-            rbln_config is not None
-            and "format" in rbln_config.quantization
-            and rbln_config.quantization["format"] == "rbln"
-        ):
+        if rbln_config and rbln_config.quantization:
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)
@@ -537,6 +606,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             "kvcache_block_size": rbln_config.kvcache_block_size,
             "use_rotary_emb": cls._use_rotary_emb,
             "use_attention_mask": rbln_config.use_attention_mask,
+            "use_position_ids": rbln_config.use_position_ids,
+            "use_inputs_embeds": rbln_config.use_inputs_embeds,
         }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

@@ -547,7 +618,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         rbln_compile_configs = rbln_config.compile_cfgs
         prefill_compile_config = rbln_compile_configs[0]
-        dec_compile_config = rbln_compile_configs[1]

         context = CompileContext(use_weight_sharing=True)

@@ -562,33 +632,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             static_tensors[name] = tensor
             context.mark_static_address(tensor)

-
-
-        @QuantizationManager.with_quantization_env
-        def compile_model(*args, **kwargs):
+        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
             try:
+                if quantization:
+                    quantization.maybe_set_quantization_env()
                 original_linear = torch.nn.functional.linear
                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-
-                compiled_prefill = RBLNModel.compile(
-                    wrapped_model,
-                    prefill_compile_config,
-                    example_inputs=prefill_example_inputs,
-                    compile_context=context,
-                )
-
-                wrapped_model.phase = "decode"
-                compiled_decoder = RBLNModel.compile(
+                compiled_model = RBLNModel.compile(
                     wrapped_model,
-
-                    example_inputs=
-                    compile_context=
+                    compile_config,
+                    example_inputs=example_inputs,
+                    compile_context=compile_context,
                 )
-                return
+                return compiled_model
             finally:
                 torch.nn.functional.linear = original_linear
+                if quantization:
+                    quantization.maybe_reset_quantization_env()

-
+        wrapped_model.phase = "prefill"
+        compiled_prefill = compile_model(
+            wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
+        )
+
+        wrapped_model.phase = "decode"
+        compiled_models = {"prefill": compiled_prefill}
+        for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
+            dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+            compiled_decoder = compile_model(
+                wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
+            )
+            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

         # check if the memory is enough to have additional blocks
         required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
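The refactored `compile_model` helper above wraps every compile call in the same set-up and tear-down: optionally enter the quantization environment, temporarily swap `torch.nn.functional.linear` for the RBLN custom op, and restore both in `finally`. A minimal sketch of that pattern, with a hypothetical `compile_fn`/`quantization` pair standing in for the real objects:

```python
import torch

def compile_with_patched_linear(compile_fn, custom_linear, quantization=None):
    # Swap the linear op (and optionally the quantization env) only for the
    # duration of one compile call; always restore the originals afterwards.
    original_linear = torch.nn.functional.linear
    try:
        if quantization:
            quantization.maybe_set_quantization_env()  # method name taken from the diff above
        torch.nn.functional.linear = custom_linear
        return compile_fn()
    finally:
        torch.nn.functional.linear = original_linear
        if quantization:
            quantization.maybe_reset_quantization_env()
```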
@@ -613,8 +687,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         alloc_memory_by_key: Dict[str, int] = {
             key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
         }
-        for
-
+        for batch_size in rbln_config.decoder_batch_sizes:
+            for key, memory_per_node in (
+                compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
+            ):
+                alloc_memory_by_key[key] += sum(memory_per_node)
         alloc_memory_by_key.pop("PortRecur")  # kv-cache
         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight

@@ -650,6 +727,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         n_model_params: Optional[int] = None,
         kernel_size: Optional[int] = None,
         buffer: Optional[int] = None,
+        num_runtimes: int = 2,
     ) -> int:
         """
         We are finding max_n_blocks(x) that satisfies the following equation:
@@ -721,7 +799,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         if buffer is None:
             # TODO: Accurate buffer estimation
-
+            buffer_per_runtime_per_core = 2**28  # 256MB per runtime
+            buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
             buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer

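The new `num_runtimes` parameter feeds directly into the default buffer estimate above. A worked example with assumed values (two decoder batch sizes, tensor parallelism of 4):

```python
# Assumed illustrative values; num_runtimes = 1 prefill + len(decoder_batch_sizes).
buffer_per_runtime_per_core = 2**28          # 256 MiB per runtime
num_runtimes = 1 + 2                          # prefill + two decoders
tensor_parallel_size = 4

buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 768 MiB
buffer = buffer_per_core * tensor_parallel_size               # 3 GiB reserved from DRAM
```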
@@ -739,6 +818,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         query_length: int,
         use_inputs_embeds: bool,
         use_attention_mask: bool,
+        use_position_ids: bool,
         max_seq_len: int,
         kvcache_block_size: int,
         kvcache_num_blocks: int,
@@ -761,26 +841,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ),
         ]

-
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        if query_length > 1:
             input_info.extend(
                 [
-                    ("
+                    ("query_position", [], "int16"),
                 ]
             )
-
-        if query_length > 1:
+        if use_attention_mask:
             input_info.extend(
                 [
-                    ("
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
                 ]
             )
-
-
-
-        if query_length > 1:
-            input_info.extend([("block_tables", [max_block_cnt], "int16")])
-        else:
-            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+        if use_position_ids:
+            input_info.append(("position_ids", [batch_size, query_length], "int32"))

         input_info.extend(
             [
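For reference, the reordered branches above produce entries along these lines. A sketch with assumed shapes (batch_size=2, query_length=1, max_seq_len=4096, kvcache_block_size=4096), not the full input spec:

```python
max_seq_len = 4096
kvcache_block_size = 4096
batch_size, query_length = 2, 1
max_block_cnt = max_seq_len // kvcache_block_size  # 1

input_info = []
# Decode phase (query_length == 1): per-batch block tables, no query_position entry.
input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
# With masks and position ids enabled, these entries follow:
input_info.extend([("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32")])
input_info.append(("position_ids", [batch_size, query_length], "int32"))
```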
@@ -839,6 +920,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             kvcache_block_size=rbln_config.kvcache_block_size,
             nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
             n_model_params=sum(p.numel() for p in model.parameters()),
+            num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
         )

         max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
@@ -873,19 +955,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             query_length=rbln_config.prefill_chunk_size,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
             use_attention_mask=rbln_config.use_attention_mask,
-
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-        )
-        dec_input_info = cls.get_input_info(
-            batch_size=rbln_config.batch_size,
-            query_length=1,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
+            use_position_ids=rbln_config.use_position_ids,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -896,9 +966,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-
+        dec_compile_configs = []
+        for batch_size in rbln_config.decoder_batch_sizes:
+            dec_input_info = cls.get_input_info(
+                batch_size=batch_size,
+                query_length=1,
+                use_inputs_embeds=rbln_config.use_inputs_embeds,
+                use_attention_mask=rbln_config.use_attention_mask,
+                use_position_ids=rbln_config.use_position_ids,
+                max_seq_len=rbln_config.max_seq_len,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+                num_key_value_heads=num_key_value_heads,
+                num_hidden_layers=num_hidden_layers,
+                hidden_size=hidden_size,
+                head_dim=head_dim,
+            )
+            dec_compile_configs.append(
+                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+            )
+        rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])

         return rbln_config

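With one prefill config plus one decoder config per entry in `decoder_batch_sizes`, the compiled-model names follow a fixed scheme. A small sketch assuming `decoder_batch_sizes=[4, 1]` (illustrative values only):

```python
decoder_batch_sizes = [4, 1]  # assumed example
compiled_model_names = ["prefill"] + [f"decoder_batch_{bs}" for bs in decoder_batch_sizes]
# -> ["prefill", "decoder_batch_4", "decoder_batch_1"]
# Each of these names must also be present in rbln_config.device_map when runtimes are created.
```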
@@ -908,8 +996,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-
-
+        expected_model_names = [
+            "prefill",
+            *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
+        ]
+        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+            cls._raise_missing_compiled_file_error(expected_model_names)

         return [
             rebel.Runtime(
@@ -918,12 +1010,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
             ),
-
-
-
-
-
-
+            *[
+                rebel.Runtime(
+                    compiled_models[i + 1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                    activate_profiler=rbln_config.activate_profiler,
+                )
+                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+            ],
         ]

     def get_decoder(self):
@@ -941,6 +1036,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
         **kwargs,
     ):
         model_inputs = {}
@@ -948,13 +1044,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         if is_prefill_phase:
             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            padded_cache_lengths = torch.zeros_like(generate_idx)
             cache_position = None
+            position_ids = None
         else:
             if inputs_embeds is not None:
-
+                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+                inputs_embeds = None

             input_ids = input_ids[:, -1:]
-
+            position_ids = generate_idx
+            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
             generate_idx = generate_idx + 1
             model_inputs.update({"input_ids": input_ids})

@@ -973,6 +1073,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "attention_mask": attention_mask,
                 "cache_position": cache_position,
                 "generate_idx": generate_idx,
+                "position_ids": position_ids,
+                "padded_cache_lengths": padded_cache_lengths,
             }
         )

@@ -986,6 +1088,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> Dict[str, Any]:
         # update generate_idx
         model_kwargs["generate_idx"] = outputs.generate_idx
+        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths

         return model_kwargs

@@ -996,6 +1099,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
+        padded_cache_lengths: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         """
@@ -1009,28 +1115,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
-
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                output = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                 )
-
-
+                padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                logits.append(output.logits)
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
-
+            inputs = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = inputs.shape[0]
+            if batch_size not in self.decoders:
+                raise ValueError(
+                    f"No decoder runtime available for batch size {batch_size}. "
+                    f"Available batch sizes are: {list(self.decoders.keys())}. "
+                    f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
+                )
+            logits = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-
+                position_ids=position_ids if self.rbln_config.use_position_ids else None,
+            ).logits

-
-            logits
-
-
+        if not return_dict:
+            return logits, generate_idx, padded_cache_lengths
+        else:
+            return RBLNDecoderOnlyOutput(
+                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+            )