optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (101)
  1. optimum/rbln/__init__.py +164 -36
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +107 -78
  31. optimum/rbln/transformers/__init__.py +87 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +108 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  76. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  77. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  78. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  79. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  80. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  81. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  82. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  83. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  84. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  85. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  86. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  87. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  88. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  89. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  90. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  91. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  92. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  93. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  94. optimum/rbln/utils/runtime_utils.py +33 -2
  95. optimum/rbln/utils/submodule.py +26 -43
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
  97. optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
  98. optimum/rbln/modeling_config.py +0 -310
  99. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  100. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
  101. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
@@ -26,13 +26,15 @@ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, Pre
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...utils.rbln_quantization import QuantizationManager
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
+    set_default_values,
     validate_attention_method,
 )

@@ -161,6 +163,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -185,9 +188,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 block_tables,
                 is_external_block_tables,
                 attention_mask=attention_mask,
+                position_embed=position_embed,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
+            return self.prefill_forward(
+                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+            )

     def decode_forward(
         self,
@@ -196,6 +202,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -227,6 +234,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position,
             attention_mask if self.use_attention_mask else None,
             block_tables,
+            position_embed,
         )

         return logits
@@ -239,6 +247,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -249,6 +258,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )

         query_length = inputs.shape[1]
         if query_length > self.max_seq_len:
@@ -293,9 +306,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     dim=-1,
                 )

+                if position_embed is not None:
+                    position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            if position_embed is not None:
+                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

             if self.use_attention_mask:
                 # Update attention mask to ensure proper causal behavior
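For readers tracing the new `position_embed` path: a minimal sketch (not taken from the diff; shapes are hypothetical) of how the chunked prefill above pads and slices the position embeddings in lockstep with the token inputs, padding the sequence axis (dim 1 for inputs, dim 3 for position_embed) and slicing both with the same `step` window.

import torch

prefill_chunk_size = 128
query_length = 300
inputs = torch.zeros(1, query_length, dtype=torch.int64)     # token ids
position_embed = torch.zeros(2, 1, 1, query_length, 64)      # hypothetical stacked cos/sin table

# Pad the tail so the total length is a multiple of the chunk size, mirroring
# torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size)) in the hunk above.
padding_size = prefill_chunk_size - query_length % prefill_chunk_size
inputs = torch.nn.functional.pad(inputs, (0, padding_size))
position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))

for step in range(0, query_length, prefill_chunk_size):
    input_chunk = inputs[:, step : step + prefill_chunk_size]
    position_embed_chunk = position_embed[:, :, :, step : step + prefill_chunk_size, :]
    # Both chunks always cover the same token window.
    assert input_chunk.shape[1] == position_embed_chunk.shape[3] == prefill_chunk_size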
@@ -313,6 +331,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.use_attention_mask else None,
                 query_position,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )

@@ -356,17 +375,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     _use_rotary_emb = True

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
-        # FIXME get kvcache_num_blocks from compiled results.
-        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
-        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name

-        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+        if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
             with no_init_weights():
@@ -380,40 +391,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             self.embed_tokens = None

         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-        dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
         block_tables = torch.zeros(
-            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
         ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))

         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.prefill_chunk_size,
-            max_seq_len=self.max_seq_len,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
+            max_seq_len=self.rbln_config.max_seq_len,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="decode",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )

     @classmethod
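The practical effect of this hunk is that runtime values move from the untyped `rbln_config.model_cfg` dict to typed attributes on the new config classes. A rough usage sketch, under the assumption that the per-model config classes added in this release (e.g. `RBLNLlamaForCausalLMConfig` from `configuration_llama.py`) accept these fields and are importable from `optimum.rbln`:

# Assumed import path and constructor arguments; illustrative only, not taken from the diff.
from optimum.rbln import RBLNLlamaForCausalLMConfig

rbln_config = RBLNLlamaForCausalLMConfig(batch_size=2, max_seq_len=4096)

# 0.7.4a4: rbln_config.model_cfg["max_seq_len"]
# 0.7.4a6: rbln_config.max_seq_len
print(rbln_config.batch_size, rbln_config.max_seq_len)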
@@ -422,13 +437,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
         store the torch tensor, weight, etc. in this function.
         """
-        if rbln_config.model_cfg["use_inputs_embeds"]:
+        if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -493,33 +508,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return val

     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
-        logger.debug("Loading the LLM model to the CPU.")  # TODO(jongho): Remove.
-
-        rbln_kwargs = kwargs.get("rbln_kwargs", {})
-        rbln_quantization = rbln_kwargs.get("quantization", None)
-        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
+    ) -> "PreTrainedModel":
+        if (
+            rbln_config is not None
+            and "format" in rbln_config.quantization
+            and rbln_config.quantization["format"] == "rbln"
+        ):
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)

-        logger.debug("Loaded the LLM model to the CPU.")
         return model

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
-        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
-        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
-        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
-        wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
-
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+        }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
@@ -541,8 +558,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

-        quantize_config = rbln_config.model_cfg.get("quantization", None)
-
         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
             try:
@@ -567,7 +582,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             finally:
                 torch.nn.functional.linear = original_linear

-        return compile_model(quantize_config=quantize_config)
+        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
+            )
+
+        return compiled_models
+
+    @classmethod
+    def maybe_suggest_kvcache_num_blocks(
+        cls,
+        compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        model_config: PretrainedConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> None:
+        # Get the actual memory allocation of each node by key
+        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+        alloc_memory_by_key: Dict[str, int] = {
+            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+        }
+        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+        # Get the maximum number of blocks that can be allocated
+        buffer = sum(alloc_memory_by_key.values())
+        max_num_blocks = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_config.tensor_parallel_size,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kernel_size=kernel_size,
+            buffer=buffer,
+        )
+
+        # Since our estimation logic is not always accurate,
+        # users can set `kvcache_num_blocks` to `max_num_blocks`.
+        # If the memory is not enough, the model will fail to compile.
+        if rbln_config.kvcache_num_blocks < max_num_blocks:
+            logger.warning(
+                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                "Our analysis indicates that additional memory is available for more blocks. "
+                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                "Please be advised that our memory estimation algorithm has limitations, "
+                "and increasing this value may not guarantee successful model compilation."
+            )

     @classmethod
     def get_maximum_num_blocks(
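A worked example of the block-count check introduced above, with hypothetical numbers (the real values come from the config and the compiled artifacts):

# Hypothetical values, for illustration only.
max_seq_len = 8192
kvcache_block_size = 4096
batch_size = 4

# Blocks needed so every sequence in the batch can hold max_seq_len tokens.
required_num_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 2 * 4 = 8

kvcache_num_blocks = 6  # what the model was compiled with
if kvcache_num_blocks < required_num_blocks:
    # This is where get_compiled_model() now calls maybe_suggest_kvcache_num_blocks(),
    # which measures actual per-node allocations and may recommend a larger value.
    print(f"only {kvcache_num_blocks} blocks compiled; {required_num_blocks} needed for full-length sequences")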
@@ -575,8 +640,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         config: PretrainedConfig,
         tensor_parallel_size: int,
         kvcache_block_size: int,
-        nbits_per_param: int,
-        n_model_params: int,
+        nbits_per_param: Optional[int] = None,
+        n_model_params: Optional[int] = None,
+        kernel_size: Optional[int] = None,
+        buffer: Optional[int] = None,
     ) -> int:
         """
         We are finding max_n_blocks(x) that satisfies the following equation:
@@ -626,24 +693,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)

-        # Get estimated kernel size (approximated)
-        lm_heads_params = align(vocab_size, 64) * hidden_size
-        lm_heads_nbytes = (
-            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-        )
-        params = n_model_params - lm_heads_params
-        layer_nbytes = (
-            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-            * num_layers
-            * tensor_parallel_size
-        )
-        kernel_size = layer_nbytes + lm_heads_nbytes
+        if kernel_size is None:
+            if n_model_params is None:
+                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+            # Get estimated kernel size (approximated)
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+            kernel_size = layer_nbytes + lm_heads_nbytes
+        elif n_model_params is not None:
+            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")

         available_dram -= kernel_size

-        # TODO: Accurate buffer estimation
-        buffer_per_core = 2**29  # 500MB per npu
-        buffer = buffer_per_core * tensor_parallel_size
+        if buffer is None:
+            # TODO: Accurate buffer estimation
+            buffer_per_core = 2**29  # 500MB per npu
+            buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer

         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
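For reference, a standalone sketch of the weight-footprint estimate that this hunk makes optional (it now runs only when `kernel_size` is not supplied from the compiled artifacts). `align`/`align_2MB` are assumed to be the round-up helpers used elsewhere in this file, and the model dimensions are hypothetical:

import math

def align(x, n):  # assumed helper: round x up to a multiple of n
    return int(math.ceil(x / n) * n)

def align_2MB(x):  # assumed helper: round x up to a 2 MiB boundary
    return align(x, 2 * 1024 * 1024)

vocab_size, hidden_size, num_layers = 32000, 4096, 32
tensor_parallel_size = 4
nbits_per_param = 16            # 4 when weights are quantized, per the TODO above
n_model_params = 7_000_000_000  # hypothetical parameter count

lm_heads_params = align(vocab_size, 64) * hidden_size
lm_heads_nbytes = align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
params = n_model_params - lm_heads_params
layer_nbytes = (
    align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
    * num_layers
    * tensor_parallel_size
)
kernel_size = layer_nbytes + lm_heads_nbytes  # approximate weight bytes subtracted from the DRAM budget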
@@ -654,184 +727,172 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return max_n_blocks

     @classmethod
-    def _get_rbln_config(
+    def get_input_info(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
-        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
-        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
-        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
-        rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
-
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
-
-        if rbln_prefill_chunk_size is None:
-            rbln_prefill_chunk_size = 128
-        elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
-            raise ValueError(
-                f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
+        batch_size: int,
+        query_length: int,
+        use_inputs_embeds: bool,
+        use_attention_mask: bool,
+        max_seq_len: int,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
+        num_key_value_heads: int,
+        num_hidden_layers: int,
+        hidden_size: int,
+        head_dim: int,
+    ):
+        if use_inputs_embeds:
+            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+        else:
+            main_input = ("input_ids", [batch_size, query_length], "int64")
+
+        input_info = [
+            main_input,
+            (
+                "cache_position",
+                [batch_size, query_length],
+                "int32",
+            ),
+        ]
+
+        if use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                ]
             )

-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("query_position", [], "int16"),
+                ]
             )
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified.")

-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        input_info.extend(
+            [
+                (
+                    f"past_key_values_{i}",
+                    [
+                        kvcache_num_blocks,
+                        num_key_value_heads,
+                        kvcache_block_size,
+                        head_dim,
+                    ],
+                    "float32",
+                )
+                for i in range(num_hidden_layers * 2)
+            ]
+        )
+
+        return input_info

-        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
-            rbln_attn_impl=rbln_attn_impl,
-            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
-            rbln_kvcache_block_size=rbln_kvcache_block_size,
-            rbln_max_seq_len=rbln_max_seq_len,
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
         )

-        if rbln_kvcache_block_size is None:
-            if rbln_attn_impl == "eager":
-                rbln_kvcache_block_size = rbln_max_seq_len
-            else:
-                rbln_kvcache_block_size = rbln_kvcache_partition_len
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        max_num_blocks = required_num_blocks

-        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
-        if rbln_attn_impl == "flash_attn":
-            max_num_blocks = cls.get_maximum_num_blocks(
+        if rbln_config.attn_impl == "flash_attn":
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
-                tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
-                kvcache_block_size=rbln_kvcache_block_size,
-                nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=rbln_kwargs["n_model_params"],
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=sum(p.numel() for p in model.parameters()),
             )
-            rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)

-            required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
-            if rbln_kvcache_num_blocks < required_blocks:
-                rbln_kvcache_num_blocks = required_blocks
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)

-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks

-        if rbln_kvcache_num_blocks < rbln_batch_size:
+        if max_num_blocks < rbln_config.batch_size:
             raise RuntimeError(
-                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
                 "Ensure the number of blocks is at least equal to the batch size."
             )

+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+        elif rbln_config.kvcache_num_blocks > max_num_blocks:
+            logger.warning(
+                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                f" than the estimated maximum number of blocks ({max_num_blocks})."
+                "This can cause a failure during model compilation."
+            )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads

-        def get_input_info(
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            hidden_size,
-        ):
-            if use_inputs_embeds:
-                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-            else:
-                main_input = ("input_ids", [batch_size, query_length], "int64")
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
-
-            if rbln_use_attention_mask:
-                input_info.extend(
-                    [
-                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
-                    ]
-                )
-
-            if query_length > 1:
-                input_info.extend(
-                    [
-                        ("query_position", [], "int16"),
-                    ]
-                )
-
-            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
-
-            if query_length > 1:
-                input_info.extend([("block_tables", [max_block_cnt], "int16")])
-            else:
-                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_kvcache_num_blocks,
-                            num_key_value_heads,
-                            rbln_kvcache_block_size,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
+        prefill_input_info = cls.get_input_info(
             batch_size=1,
-            query_length=rbln_prefill_chunk_size,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            query_length=rbln_config.prefill_chunk_size,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
         )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
+        dec_input_info = cls.get_input_info(
+            batch_size=rbln_config.batch_size,
             query_length=1,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "prefill_chunk_size": rbln_prefill_chunk_size,
-                "use_attention_mask": rbln_use_attention_mask,
-                "use_inputs_embeds": rbln_use_inputs_embeds,
-                "kvcache_partition_len": rbln_kvcache_partition_len,
-                "kvcache_block_size": rbln_kvcache_block_size,
-                "attn_impl": rbln_attn_impl,
-                "kvcache_num_blocks": rbln_kvcache_num_blocks,
-            }
-        )
-
-        if rbln_quantization is not None:
-            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+        rbln_config.set_compile_cfgs([prefill_compile_config, dec_compile_config])

         return rbln_config

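To make the new `get_input_info` contract concrete, here is what the decoder-phase call above would return for a small hypothetical configuration (derived by reading the method body; the class import path is assumed):

dec_input_info = RBLNDecoderOnlyModelForCausalLM.get_input_info(
    batch_size=2,
    query_length=1,
    use_inputs_embeds=False,
    use_attention_mask=False,
    max_seq_len=4096,
    kvcache_block_size=4096,
    kvcache_num_blocks=2,
    num_key_value_heads=8,
    num_hidden_layers=2,
    hidden_size=2048,
    head_dim=64,
)
# dec_input_info ==
# [("input_ids", [2, 1], "int64"),
#  ("cache_position", [2, 1], "int32"),
#  ("block_tables", [2, 1], "int16"),
#  ("past_key_values_0", [2, 8, 4096, 64], "float32"),
#  ... up to "past_key_values_3"]   # one entry per layer for each of K and V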
@@ -839,18 +900,23 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["prefill", "decoder"]):
             cls._raise_missing_compiled_file_error(["prefill", "decoder"])

         return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]

@@ -887,11 +953,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             model_inputs.update({"input_ids": input_ids})

         if inputs_embeds is not None:
-            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            if self.rbln_config.use_inputs_embeds:
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
-                    "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                 )
         else:
             model_inputs.update({"input_ids": input_ids})