optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (file 71 in the list above)
@@ -26,13 +26,15 @@ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, Pre
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...utils.rbln_quantization import QuantizationManager
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
+    set_default_values,
     validate_attention_method,
 )
 
@@ -161,6 +163,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -185,9 +188,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 block_tables,
                 is_external_block_tables,
                 attention_mask=attention_mask,
+                position_embed=position_embed,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
+            return self.prefill_forward(
+                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+            )
 
     def decode_forward(
         self,
@@ -196,6 +202,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -222,13 +229,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
             attention_mask = self.dec_attn_mask
 
-        attention_mask = self.dec_attn_mask
-
         logits = super().forward(
             inputs,
             cache_position,
             attention_mask if self.use_attention_mask else None,
             block_tables,
+            position_embed,
         )
 
         return logits
@@ -241,6 +247,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -251,6 +258,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )
 
         query_length = inputs.shape[1]
         if query_length > self.max_seq_len:
@@ -295,9 +306,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     dim=-1,
                 )
 
+                if position_embed is not None:
+                    position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            if position_embed is not None:
+                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
 
             if self.use_attention_mask:
                 # Update attention mask to ensure proper causal behavior
@@ -315,6 +331,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.use_attention_mask else None,
                 query_position,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )
 
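Note: the hunks above thread `position_embed` through the chunked prefill path, which pads the prompt to a multiple of `prefill_chunk_size` and then feeds the compiled prefill graph one fixed-size chunk at a time. A minimal standalone sketch of that pad-and-slice pattern (tensor names and the chunk size are illustrative, not the library code):

import torch

prefill_chunk_size = 128                     # assumed compile-time chunk size
input_ids = torch.randint(0, 100, (1, 300))  # hypothetical prompt of length 300
query_length = input_ids.shape[1]

# Pad up to a multiple of the chunk size so every chunk has the same shape.
if query_length % prefill_chunk_size != 0:
    padding_size = prefill_chunk_size - query_length % prefill_chunk_size
    input_ids = torch.nn.functional.pad(input_ids, (0, padding_size))

# The compiled prefill graph always sees inputs of shape [1, prefill_chunk_size].
for step in range(0, input_ids.shape[1], prefill_chunk_size):
    input_chunk = input_ids[:, step : step + prefill_chunk_size]
    # ... run the prefill runtime on input_chunk and the matching cache_position slice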
@@ -358,17 +375,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     _use_rotary_emb = True
 
     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
-        # FIXME get kvcache_num_blocks from compiled results.
-        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
-        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
 
-        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+        if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
             with no_init_weights():
@@ -382,40 +391,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             self.embed_tokens = None
 
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-        dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
         block_tables = torch.zeros(
-            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
         ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
 
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.prefill_chunk_size,
-            max_seq_len=self.max_seq_len,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
+            max_seq_len=self.rbln_config.max_seq_len,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="decode",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
            free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
        )
 
     @classmethod
@@ -424,13 +437,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
         store the torch tensor, weight, etc. in this function.
         """
-        if rbln_config.model_cfg["use_inputs_embeds"]:
+        if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -438,6 +451,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_input_embeddings(self):
         return self.embed_tokens
 
+    def get_attn_impl(self) -> str:
+        return self.rbln_config.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.kvcache_num_blocks
+
     @classmethod
     def get_quantized_model(
         cls,
@@ -495,33 +514,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return val
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
-        logger.debug("Loading the LLM model to the CPU.")  # TODO(jongho): Remove.
-
-        rbln_kwargs = kwargs.get("rbln_kwargs", {})
-        rbln_quantization = rbln_kwargs.get("quantization", None)
-        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
+    ) -> "PreTrainedModel":
+        if (
+            rbln_config is not None
+            and "format" in rbln_config.quantization
+            and rbln_config.quantization["format"] == "rbln"
+        ):
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)
 
-        logger.debug("Loaded the LLM model to the CPU.")
         return model
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
-        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
-        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
-        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
-        wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
-
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+        }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
@@ -543,28 +564,81 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
 
-        quantize_config = rbln_config.model_cfg.get("quantization", None)
-
         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-            wrapped_model.phase = "prefill"
-            compiled_prefill = RBLNModel.compile(
-                wrapped_model,
-                prefill_compile_config,
-                example_inputs=prefill_example_inputs,
-                compile_context=context,
-            )
+            try:
+                original_linear = torch.nn.functional.linear
+                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+                wrapped_model.phase = "prefill"
+                compiled_prefill = RBLNModel.compile(
+                    wrapped_model,
+                    prefill_compile_config,
+                    example_inputs=prefill_example_inputs,
+                    compile_context=context,
+                )
 
-            wrapped_model.phase = "decode"
-            compiled_decoder = RBLNModel.compile(
-                wrapped_model,
-                dec_compile_config,
-                example_inputs=dec_example_inputs,
-                compile_context=context,
+                wrapped_model.phase = "decode"
+                compiled_decoder = RBLNModel.compile(
+                    wrapped_model,
+                    dec_compile_config,
+                    example_inputs=dec_example_inputs,
+                    compile_context=context,
+                )
+                return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+            finally:
+                torch.nn.functional.linear = original_linear
+
+        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
            )
-            return {"prefill": compiled_prefill, "decoder": compiled_decoder}
 
-        return compile_model(quantize_config=quantize_config)
+        return compiled_models
+
+    @classmethod
+    def maybe_suggest_kvcache_num_blocks(
+        cls,
+        compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        model_config: PretrainedConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> None:
+        # Get the actual memory allocation of each node by key
+        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+        alloc_memory_by_key: Dict[str, int] = {
+            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+        }
+        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+        # Get the maximum number of blocks that can be allocated
+        buffer = sum(alloc_memory_by_key.values())
+        max_num_blocks = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_config.tensor_parallel_size,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kernel_size=kernel_size,
+            buffer=buffer,
+        )
+
+        # Since our estimation logic is not always accurate,
+        # users can set `kvcache_num_blocks` to `max_num_blocks`.
+        # If the memory is not enough, the model will fail to compile.
+        if rbln_config.kvcache_num_blocks < max_num_blocks:
+            logger.warning(
+                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                "Our analysis indicates that additional memory is available for more blocks. "
+                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                "Please be advised that our memory estimation algorithm has limitations, "
+                "and increasing this value may not guarantee successful model compilation."
            )
 
     @classmethod
@@ -572,14 +646,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         config: PretrainedConfig,
         tensor_parallel_size: int,
         kvcache_block_size: int,
-        nbits_per_param: int,
-        n_model_params: int,
+        nbits_per_param: Optional[int] = None,
+        n_model_params: Optional[int] = None,
+        kernel_size: Optional[int] = None,
+        buffer: Optional[int] = None,
     ) -> int:
+        """
+        We are finding max_n_blocks(x) that satisfies the following equation:
+
+        available_dram - kernel_size - buffer
+        - num_layers * 2 * tensor_parallel_size
+        * align_2MB(
+            x
+            * block_size
+            * align_64(head_dim)
+            * math.ceil(num_key_value_heads / tensor_parallel_size)
+            * 2
+        ) > 0
+
+        This inequality can be rewritten as follows:
+
+        a - c * align_2MB(b * x) > 0
+        where
+            a = available_dram - kernel_size - buffer
+            b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+            c = num_layers * 2 * tensor_parallel_size
+
+        We can rewrite the inequality as follows:
+            k > align_2MB(b*x)
+        where
+            k = a / c
+
+        After that, we can derive the following equation:
+            x = floor(2**21 / b * floor((k - 1) / 2**21))
+        """
+
         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
 
         def align_2MB(x: int) -> int:
-            return align(x, 2 * 1024 * 1024)
+            return align(x, 2**21)
 
         num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
         num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
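The docstring added above derives a closed form for the largest KV-cache block count that fits in DRAM, x = floor(2**21 / b * floor((k - 1) / 2**21)). A small worked example of that formula with made-up numbers (the DRAM, kernel, and buffer sizes below are assumptions, not device measurements):

import math

def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)

atom_dram_nbytes = 16 * 2**30          # assumed per-device DRAM; the real constant is not shown in this hunk
tensor_parallel_size = 4
available_dram = tensor_parallel_size * (atom_dram_nbytes - 288 * 2**20)
kernel_size = 14 * 2**30               # assumed: ~14 GiB of model weights
buffer = tensor_parallel_size * 2**29  # 500MB per NPU, as in the new default

num_layers, num_key_value_heads, head_dim = 32, 8, 128
kvcache_block_size = 16384

a = available_dram - kernel_size - buffer
b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
c = num_layers * 2 * tensor_parallel_size
k = a / c
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
print(max_n_blocks)  # 23 blocks with these assumed numbers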
@@ -593,223 +699,206 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
 
-        # Get estimated kernel size (approximated)
-        lm_heads_params = align(vocab_size, 64) * hidden_size
-        lm_heads_nbytes = (
-            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-        )
-        params = n_model_params - lm_heads_params
-        layer_nbytes = (
-            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-            * num_layers
-            * tensor_parallel_size
-        )
-        kernel_size = layer_nbytes + lm_heads_nbytes
+        if kernel_size is None:
+            if n_model_params is None:
+                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+            # Get estimated kernel size (approximated)
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+            kernel_size = layer_nbytes + lm_heads_nbytes
+        elif n_model_params is not None:
+            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
 
         available_dram -= kernel_size
 
-        # TODO: Accurate buffer estimation
-        buffer = 2**30  # 1GB Buffer
-        if tensor_parallel_size <= 4:
-            buffer /= 4
-
+        if buffer is None:
+            # TODO: Accurate buffer estimation
+            buffer_per_core = 2**29  # 500MB per npu
+            buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer
 
-        # Estimate nbytes per a single kvcache block
-        nbytes_per_block = (
-            align_2MB(
-                kvcache_block_size
-                * head_dim
-                * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
-                * 2  # (fp16)
-            )
-            * num_layers
-            * 2  # (k, v)
-            * tensor_parallel_size
-        )
-        n_blocks = available_dram // nbytes_per_block
+        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        c = num_layers * 2 * tensor_parallel_size
+        k = available_dram / c
+        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
 
-        return n_blocks, nbytes_per_block
+        return max_n_blocks
 
     @classmethod
-    def _get_rbln_config(
+    def get_input_info(
         cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
-        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
-        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
-        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
-        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
-        rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
-
-        if rbln_use_attention_mask is None:
-            rbln_use_attention_mask = False
-            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
-            if rbln_npu == "RBLN-CA02":
-                rbln_use_attention_mask = True
-
-        if rbln_prefill_chunk_size is None:
-            rbln_prefill_chunk_size = 128
-        elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
-            raise ValueError(
-                f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
+        batch_size: int,
+        query_length: int,
+        use_inputs_embeds: bool,
+        use_attention_mask: bool,
+        max_seq_len: int,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
+        num_key_value_heads: int,
+        num_hidden_layers: int,
+        hidden_size: int,
+        head_dim: int,
+    ):
+        if use_inputs_embeds:
+            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+        else:
+            main_input = ("input_ids", [batch_size, query_length], "int64")
+
+        input_info = [
+            main_input,
+            (
+                "cache_position",
+                [batch_size, query_length],
+                "int32",
+            ),
+        ]
+
+        if use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                ]
            )
 
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("query_position", [], "int16"),
+                ]
            )
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified.")
 
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
 
-        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
-            rbln_attn_impl=rbln_attn_impl,
-            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
-            rbln_kvcache_block_size=rbln_kvcache_block_size,
-            rbln_max_seq_len=rbln_max_seq_len,
+        input_info.extend(
+            [
+                (
+                    f"past_key_values_{i}",
+                    [
+                        kvcache_num_blocks,
+                        num_key_value_heads,
+                        kvcache_block_size,
+                        head_dim,
+                    ],
+                    "float32",
+                )
+                for i in range(num_hidden_layers * 2)
+            ]
        )
 
-        if rbln_kvcache_block_size is None:
-            if rbln_attn_impl == "eager":
-                rbln_kvcache_block_size = rbln_max_seq_len
-            else:
-                rbln_kvcache_block_size = rbln_kvcache_partition_len
+        return input_info
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        max_num_blocks = required_num_blocks
 
-        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
-        if rbln_attn_impl == "flash_attn":
-            max_num_blocks, _ = cls.get_maximum_num_blocks(
+        if rbln_config.attn_impl == "flash_attn":
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
-                tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
-                kvcache_block_size=rbln_kvcache_block_size,
-                nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=rbln_kwargs["n_model_params"],
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=sum(p.numel() for p in model.parameters()),
            )
-            rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)
 
-            required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
-            if rbln_kvcache_num_blocks < required_blocks:
-                rbln_kvcache_num_blocks = required_blocks
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
 
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks
 
-        if rbln_kvcache_num_blocks < rbln_batch_size:
+        if max_num_blocks < rbln_config.batch_size:
             raise RuntimeError(
-                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
                 "Ensure the number of blocks is at least equal to the batch size."
            )
 
+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+        elif rbln_config.kvcache_num_blocks > max_num_blocks:
+            logger.warning(
+                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                f" than the estimated maximum number of blocks ({max_num_blocks})."
+                "This can cause a failure during model compilation."
+            )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
 
-        def get_input_info(
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            hidden_size,
-        ):
-            if use_inputs_embeds:
-                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-            else:
-                main_input = ("input_ids", [batch_size, query_length], "int64")
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
-
-            if rbln_use_attention_mask:
-                input_info.extend(
-                    [
-                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
-                    ]
-                )
-
-            if query_length > 1:
-                input_info.extend(
-                    [
-                        ("query_position", [], "int16"),
-                    ]
-                )
-
-            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
-
-            if query_length > 1:
-                input_info.extend([("block_tables", [max_block_cnt], "int16")])
-            else:
-                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_kvcache_num_blocks,
-                            num_key_value_heads,
-                            rbln_kvcache_block_size,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
+        prefill_input_info = cls.get_input_info(
             batch_size=1,
-            query_length=rbln_prefill_chunk_size,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            query_length=rbln_config.prefill_chunk_size,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
        )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
+        dec_input_info = cls.get_input_info(
+            batch_size=rbln_config.batch_size,
             query_length=1,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
        )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "prefill_chunk_size": rbln_prefill_chunk_size,
-                "use_attention_mask": rbln_use_attention_mask,
-                "use_inputs_embeds": rbln_use_inputs_embeds,
-                "kvcache_partition_len": rbln_kvcache_partition_len,
-                "kvcache_block_size": rbln_kvcache_block_size,
-                "attn_impl": rbln_attn_impl,
-                "kvcache_num_blocks": rbln_kvcache_num_blocks,
-            }
-        )
-
-        if rbln_quantization is not None:
-            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+        rbln_config.set_compile_cfgs([prefill_compile_config, dec_compile_config])
 
         return rbln_config
 
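For reference, the `input_info` built by the new `get_input_info` classmethod is a flat list of (name, shape, dtype) tuples. An illustrative prefill-phase layout for a tiny hypothetical decoder (2 layers, 2 KV heads, head_dim 64, max_seq_len 1024, kvcache_block_size 256, kvcache_num_blocks 8, attention mask disabled); the numbers are made up, only the format comes from the code above:

prefill_input_info = [
    ("input_ids", [1, 128], "int64"),                   # batch 1, one prefill chunk of 128 tokens
    ("cache_position", [1, 128], "int32"),
    ("query_position", [], "int16"),                    # present because query_length > 1
    ("block_tables", [4], "int16"),                     # max_seq_len 1024 // kvcache_block_size 256
    ("past_key_values_0", [8, 2, 256, 64], "float32"),  # [num_blocks, kv_heads, block_size, head_dim]
    ("past_key_values_1", [8, 2, 256, 64], "float32"),
    ("past_key_values_2", [8, 2, 256, 64], "float32"),
    ("past_key_values_3", [8, 2, 256, 64], "float32"),  # 2 layers x 2 (K and V)
]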
@@ -817,18 +906,23 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["prefill", "decoder"]):
             cls._raise_missing_compiled_file_error(["prefill", "decoder"])
 
         return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
            ),
         ]
 
@@ -865,11 +959,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             model_inputs.update({"input_ids": input_ids})
 
         if inputs_embeds is not None:
-            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            if self.rbln_config.use_inputs_embeds:
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
-                    "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                 )
         else:
             model_inputs.update({"input_ids": input_ids})
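The error message in the last hunk refers to the compile-time `rbln_use_inputs_embeds` flag: passing `inputs_embeds` at generation time only works if the model was compiled with it enabled. A hedged usage sketch following the usual optimum-rbln conventions (class name, checkpoint, and keyword arguments are illustrative and not verified against 0.7.4):

from optimum.rbln import RBLNLlamaForCausalLM  # assumed model class

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # assumed checkpoint
    export=True,                  # compile the Hugging Face checkpoint for RBLN NPUs
    rbln_use_inputs_embeds=True,  # required before inputs_embeds can be passed at generation time
    rbln_max_seq_len=4096,
    rbln_batch_size=1,
)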