optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to their public registry. It is provided for informational purposes only and reflects the packages exactly as released.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
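
The dominant change in this release is structural: modeling_base.py shrinks by 760 lines, split into the new modeling.py (+238) and transformers/modeling_generic.py (+385), while the custom generation utilities (cache_utils.py, generation/streamers.py, generation/utils.py) are deleted outright. Downstream code that imports by internal path needs updating; a hedged sketch inferred from the import hunks below, not from official release notes:

    # 0.1.13 internal path (removed in 0.2.0):
    # from optimum.rbln.modeling_base import RBLNModel

    # 0.2.0 path, matching the rewritten relative imports ("from ....modeling import RBLNModel"):
    from optimum.rbln.modeling import RBLNModel

    # The top-level re-export is presumably the stable spelling across both versions:
    from optimum.rbln import RBLNModel

The representative hunks below are from optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (item 52) unless noted otherwise.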
@@ -20,45 +20,34 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-import functools
-import glob
+
 import inspect
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-import transformers
-from safetensors.torch import load_file
+from rebel.compile_context import CompileContext
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput
 
-from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.timer_utils import rbln_timer
-from .decoderonly_architecture import DecoderOnlyWrapper
+from ...utils.rbln_quantization import QuantizationManager
+from .decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    validate_attention_method,
+)
 
 
 logger = get_logger()
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
-
-SUPPORTED_QUANTIZATIONS = {
-    "rbln": [
-        "w4a16",
-    ],
-}
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class RBLNRuntimeModel(RBLNPytorchRuntime):
@@ -70,32 +59,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         inputs_embeds: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        query_idx: torch.Tensor,
         **kwargs,
     ):
         if inputs_embeds is None:
             inp = input_ids
             if self.embed_tokens is not None:
                 inp = self.embed_tokens(inp)
-
-            return super().forward(
-                inp,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
         else:
-            return super().forward(
-                inputs_embeds,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
+            inp = inputs_embeds
+
+        return super().forward(
+            inp,
+            attention_mask,
+            cache_position,
+            **kwargs,
+        )
 
 
 @dataclass
@@ -127,24 +105,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
     _decoder_wrapper_cls = DecoderOnlyWrapper
-    _original_cls = None
-
-    @classmethod
-    @property
-    def original_cls(cls):
-        """
-        Lazily loads and caches the corresponding Hugging Face model class.
-        Removes 'RBLN' prefix from the class name to get the original class name
-        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-        the transformers module.
-
-        Returns:
-            type: The original Hugging Face model class
-        """
-        if cls._original_cls is None:
-            hf_original_cls_name = cls.__name__[4:]
-            cls._original_cls = getattr(transformers, hf_original_cls_name)
-        return cls._original_cls
+    _use_rotary_emb = True
 
     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -203,6 +164,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -212,57 +174,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import update_layers_to_quantized
+        from ...utils.rbln_quantization import prepare_model_for_quantization
 
         kwargs = cls.update_kwargs(kwargs)
 
-        config = AutoConfig.from_pretrained(
-            model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )
 
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)
 
-        update_layers_to_quantized(model)
-
-        n_layer = kwargs.get("num_hidden_layers", None)
-        cls._load_weights_directly_to_model(model, model_id, n_layer)
+        prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))
 
         return model
 
-    def _load_weights_directly_to_model(model, model_id, n_layer=None):
-        """
-        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
-        """
-
-        model_params = dict(model.named_parameters(recurse=True))
-        model_buffers = dict(model.named_buffers(recurse=True))
-        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
-
-        target_layers = list(range(n_layer)) if n_layer is not None else None
-
-        for safetensor_file in safetensor_files:
-            file_data = load_file(safetensor_file)
-            for key, value in file_data.items():
-                if target_layers is not None:
-                    parts = key.split(".")
-
-                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                        continue
-
-                if key in model_params:
-                    model_params[key].data.copy_(value)
-                elif key in model_buffers:
-                    model_buffers[key].data.copy_(value)
-
-        return 0
-
     def __getattr__(self, __name: str) -> Any:
         """
         Special method to delegate attribute access to the original Huggingface LM class.
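
Review note: the inline safetensors loader removed above is consolidated into prepare_model_for_quantization (rbln_quantization.py grows by 120 lines, item 76). A minimal sketch of what that helper presumably does, reconstructed from the removed _load_weights_directly_to_model; only the helper's name and call site are confirmed by the diff, the body here is an assumption:

    import glob
    from safetensors.torch import load_file

    def prepare_model_for_quantization_sketch(model, model_id, n_layer=None):
        # Presumably the helper also swaps nn.Linear modules for quantized ones first
        # (the role of the removed update_layers_to_quantized), then loads weights.
        params = dict(model.named_parameters(recurse=True))
        buffers = dict(model.named_buffers(recurse=True))
        target_layers = set(range(n_layer)) if n_layer is not None else None
        for path in glob.glob(f"{model_id}/*.safetensors"):
            for key, value in load_file(path).items():
                parts = key.split(".")
                # Skip weights of layers beyond a truncated depth, as the old loader did.
                if target_layers is not None and len(parts) > 2 and parts[2].isdigit() and int(parts[2]) not in target_layers:
                    continue
                if key in params:
                    params[key].data.copy_(value)
                elif key in buffers:
                    buffers[key].data.copy_(value)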
@@ -278,7 +211,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
-        val = getattr(self.original_cls, __name)
+        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
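
Review note: attribute delegation now resolves through hf_class, with PreTrainedModel as a fallback for attributes the concrete HF class does not define. The original_cls property removed earlier suggests hf_class is derived by name; a hedged sketch of that resolution (the real implementation lives elsewhere, likely utils/model_utils.py, item 81, and may differ):

    import transformers

    def resolve_hf_class(rbln_cls_name: str):
        # Assumption: strip the "RBLN" prefix and look the result up in transformers,
        # e.g. "RBLNLlamaForCausalLM" -> transformers.LlamaForCausalLM.
        return getattr(transformers, rbln_cls_name.removeprefix("RBLN"), None)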
@@ -295,61 +228,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return model
 
-    def validate_quantization_config(quantize_config):
-        if quantize_config is not None:
-            q_format = quantize_config.get("format")
-            q_precision = quantize_config.get("precision")
-
-            if q_format not in SUPPORTED_QUANTIZATIONS:
-                raise ValueError(
-                    f"Invalid quantization format: {q_format}. "
-                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
-                )
-
-            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f"Invalid precision: {q_precision} for format: {q_format}. "
-                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
-                )
-
-        return quantize_config
-
-    @classmethod
-    def set_quantize_env(cls, quantize_config):
-        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
-        quantize_config = cls.validate_quantization_config(quantize_config)
-        if quantize_config is not None:
-            q_precision = quantize_config.get("precision")
-            quant_bits = q_precision.split("w")[1].split("a")[0]
-            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
-            return RBLN_QUANT_BITS_ENV
-        return None
-
-    @classmethod
-    def reset_quantize_env(cls, env_var_name):
-        if env_var_name is not None and env_var_name in os.environ:
-            del os.environ[env_var_name]
-
-    @classmethod
-    def manage_quantize_env(cls, func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            quantize_config = kwargs.get("quantize_config")
-            quantize_env_var = cls.set_quantize_env(quantize_config)
-            try:
-                return func(*args, **kwargs)
-            finally:
-                cls.reset_quantize_env(quantize_env_var)
-
-        return wrapper
-
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-
-        # If the model wrapper supports rbln-custom-flash-attention
-        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
-            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
+        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
 
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -359,69 +243,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_rbln_compile_config = rbln_compile_configs[0]
-        dec_rbln_compile_config = rbln_compile_configs[1]
-
-        @rbln_timer("JIT trace")
-        def get_scripted_model():
-            # This function is nested to dealloc the example inputs before compilation.
-            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
-            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
-
-            prefill_scripted_model = torch.jit.trace(
-                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
-            )
-            dec_scripted_model = torch.jit.trace(
-                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
-            )
-            return prefill_scripted_model, dec_scripted_model
+        prefill_compile_config = rbln_compile_configs[0]
+        dec_compile_config = rbln_compile_configs[1]
 
-        prefill_scripted_model, dec_scripted_model = get_scripted_model()
+        context = CompileContext(use_weight_sharing=True)
 
-        @rbln_timer("Model conversion")
-        def scripted_model_to_ir():
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_compile_config.input_info],
-            )
-            return prefill_ir, dec_ir
-
-        prefill_ir, dec_ir = scripted_model_to_ir()
-        # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 5
-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
+        # Here we use meta tensor, for the memory efficiency.
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+
+        # Mark static tensors (self kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+            if "past_key_values" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
+
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
 
-        # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)
 
-        @cls.manage_quantize_env
+        @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-            # Remove quantize_config from kwargs
-            kwargs.pop("quantize_config", None)
-
-            # Call rebel.compile with the updated kwargs
-            return rebel.compile(*args, **kwargs)
-
-        compiled_model = compile_model(
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_compile_config.fusion,
-            npu=prefill_rbln_compile_config.npu,
-            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
-            use_weight_sharing=True,
-            quantize_config=quantize_config,
-        )
+            wrapped_model.phase = "prefill"
+            compiled_prefill = RBLNModel.compile(
+                wrapped_model,
+                prefill_compile_config,
+                example_inputs=prefill_example_inputs,
+                compile_context=context,
+            )
 
-        return compiled_model
+            wrapped_model.phase = "decode"
+            compiled_decoder = RBLNModel.compile(
+                wrapped_model,
+                dec_compile_config,
+                example_inputs=dec_example_inputs,
+                compile_context=context,
+            )
+            return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+
+        return compile_model(quantize_config=quantize_config)
 
     @classmethod
     def _get_rbln_config(
@@ -432,10 +293,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-
-        rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
+        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
 
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
@@ -444,9 +305,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             )
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified.")
+
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
 
+        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+            rbln_attn_impl=rbln_attn_impl,
+            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_max_seq_len=rbln_max_seq_len,
+        )
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
@@ -472,9 +340,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 [batch_size, query_length],
                 "int32",
             ),
-            ("batch_position", [], "int16"),
-            ("query_idx", [], "int16"),
         ]
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("batch_position", [], "int16"),
+                    ("query_position", [], "int16"),
+                ]
+            )
 
         input_info.extend(
             [
@@ -507,12 +380,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             hidden_size=hidden_size,
         )
 
-        prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
-        dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)
+        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+            compile_cfgs=[prefill_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )
@@ -522,6 +395,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": prefill_chunk_size,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
+                "kvcache_partition_len": rbln_kvcache_partition_len,
+                "attn_impl": rbln_attn_impl,
             }
         )
 
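Review note: attn_impl and kvcache_partition_len now travel through model_cfg, making the attention backend a load-time option validated by validate_attention_method. A hedged usage sketch; the rbln_-prefixed kwarg names mirror the rbln_kwargs lookups above, and the "flash_attn" value is inferred from the new ops/flash_attn.py module rather than from documented behavior:

    from optimum.rbln import RBLNLlamaForCausalLM

    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,                      # compile from the Hugging Face checkpoint
        rbln_max_seq_len=8192,
        rbln_batch_size=1,
        rbln_attn_impl="flash_attn",      # assumption: checked by validate_attention_method
        rbln_kvcache_partition_len=4096,  # assumption: KV-cache partition size for flash attention
    )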
@@ -532,12 +407,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+            cls._raise_missing_compiled_file_error(["prefill", "decoder"])
+
         return [
-            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(
+                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            ),
+            compiled_models[1].create_runtime(
+                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            ),
         ]
 
     def get_decoder(self):
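
Review note: runtimes are now addressed by compiled-model name rather than a single DEFAULT_COMPILED_MODEL_NAME entry plus input_info_index, and missing artifacts fail fast via _raise_missing_compiled_file_error. Illustratively (device ids are hypothetical):

    # 0.1.13: one map entry -> two runtimes carved out of compiled_models[0]
    # 0.2.0:  one device entry per named compiled model
    rbln_device_map = {"prefill": 0, "decoder": 0}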
@@ -610,8 +494,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
-        # from llava_next forward args
-        batch_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll
@@ -633,7 +515,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     input_ids=input_tensor if inputs_embeds is None else None,
                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
                     cache_position=cache_position,
-                    batch_idx=b_idx if batch_idx is None else batch_idx,  # Llava-next prefill
+                    batch_idx=b_idx,
                 )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
@@ -671,12 +553,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 ],
                 dtype=torch.float32,
                 device="cpu",
-            ),
-            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+            )
         ]
 
         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
         query_length = input_tensors.shape[1]
+        if query_length > self.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+            )
+
         _attention_mask = self.prefill_attention_mask.clone()
 
         for step in range(0, query_length, self.prefill_chunk_size):
@@ -709,15 +595,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
-            query_idx = (query_length - 1) % self.prefill_chunk_size
+            query_position = (query_length - 1) % self.prefill_chunk_size
 
-            logits, _ = self.prefill_decoder(
+            logits = self.prefill_decoder(
                 input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
                 inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
                 attention_mask=_attention_mask.contiguous(),
                 cache_position=_cache_position.contiguous(),
                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
-                query_idx=torch.tensor(query_idx, dtype=torch.int16),
+                query_position=torch.tensor(query_position, dtype=torch.int16),
                 out=out_buffers,
             )
@@ -734,48 +620,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+        if input_tensors is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
 
         batch_size = input_tensors.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-        logits, _ = self.decoder(
+        logits = self.decoder(
             input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
             inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
             attention_mask=self.dec_attn_mask.contiguous(),
             cache_position=cache_position.contiguous(),
-            batch_position=torch.tensor(0, dtype=torch.int16),
-            query_idx=torch.tensor(0, dtype=torch.int16),
         )
 
         return logits
-
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # decoder
-        else:
-            logits = self._forward_decoder(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-            )
-
-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-        )
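
Review note: vllm_forward is removed without a deprecation shim. The unified forward() earlier in this file already dispatches on cache_position (length > 1 selects the chunked prefill path, otherwise single-step decode) and presumably wraps the logits in RBLNDecoderOnlyOutput, so serving integrations are expected to call forward() directly.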
optimum/rbln/transformers/models/dpt/modeling_dpt.py (item 53):
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 from transformers import AutoModelForDepthEstimation
 from transformers.modeling_outputs import DepthEstimatorOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 