optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -20,26 +20,24 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-import functools
-import glob
-import os
-from abc import ABC
+import inspect
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

-import rebel  # noqa: F401
-import torch  # noqa: F401
-from safetensors.torch import load_file
+import rebel
+import torch
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ....utils.timer_utils import rbln_timer
+from ...utils.rbln_quantization import QuantizationManager
+from .decoderonly_architecture import DecoderOnlyWrapper


 logger = get_logger()
@@ -52,12 +50,6 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )

-SUPPORTED_QUANTIZATIONS = {
-    "rbln": [
-        "w4a16",
-    ],
-}
-

 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name", "embed_tokens"]
@@ -102,19 +94,30 @@ class RBLNDecoderOnlyOutput(ModelOutput):
     generate_idx: torch.Tensor = None


-class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
+class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     """
-    The DecoderOnly Model transformer with a language modeling head (linear layer) on top.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    A class to convert and run pre-trained transformers based DecoderOnlyForCausalLM model on RBLN devices.
-    It implements the methods to convert a pre-trained transformers DecoderOnlyForCausalLM model into a RBLN transformer model by:
-    - transferring the checkpoint weights of the original into an optimized RBLN graph,
-    - compiling the resulting graph using the RBLN compiler.
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+    - This class is designed to be subclassed by specific model implementations
+      (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+    - Subclasses should implement model-specific conversion logic.
+    - The class handles RBLN-specific optimizations automatically during compilation
     """

     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
+    _decoder_wrapper_cls = DecoderOnlyWrapper
+    _use_rotary_emb = True

     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
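Note: the new docstring frames this class as the shared base for all decoder-only ports; the user-facing entry point remains a concrete subclass's `from_pretrained`. A typical invocation looks roughly like this (model id and kwarg values are illustrative, not taken from this diff; the `rbln_*` kwargs map onto the `max_seq_len`/`batch_size` entries read by `_get_rbln_config` later in this file):

    from optimum.rbln import RBLNLlamaForCausalLM

    # Convert + compile a HF checkpoint for RBLN devices, then use it like a
    # regular transformers causal LM.
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
        export=True,                 # trace, compile, and cache the RBLN graph
        rbln_max_seq_len=4096,
        rbln_batch_size=1,
    )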
@@ -173,6 +176,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def get_quantized_model(
         cls,
         model_id: str,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -182,56 +186,47 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import update_layers_to_quantized
+        from ...utils.rbln_quantization import prepare_model_for_quantization

         kwargs = cls.update_kwargs(kwargs)

-        config = AutoConfig.from_pretrained(
-            model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )

         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)

-        update_layers_to_quantized(model)
-
-        n_layer = kwargs.get("num_hidden_layers", None)
-        cls._load_weights_directly_to_model(model, model_id, n_layer)
+        prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))

         return model

-    def _load_weights_directly_to_model(model, model_id, n_layer=None):
+    def __getattr__(self, __name: str) -> Any:
         """
-        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+        Special method to delegate attribute access to the original Huggingface LM class.
+        This method is called when an attribute is not found in the current instance's dictionary.
+        It enables transparent access to the original model's attributes and methods while maintaining
+        proper method binding.
+
+        The method implements a delegation pattern that:
+        1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        2. For other attributes: Returns them directly from the original class
         """

-        model_params = dict(model.named_parameters(recurse=True))
-        model_buffers = dict(model.named_buffers(recurse=True))
-        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
-
-        target_layers = list(range(n_layer)) if n_layer is not None else None
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

-        for safetensor_file in safetensor_files:
-            file_data = load_file(safetensor_file)
-            for key, value in file_data.items():
-                if target_layers is not None:
-                    parts = key.split(".")
-
-                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                        continue
-
-                if key in model_params:
-                    model_params[key].data.copy_(value)
-                elif key in model_buffers:
-                    model_buffers[key].data.copy_(value)
-
-        return 0
+        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val

     @classmethod
     def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
@@ -245,53 +240,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         return model

-    def validate_quantization_config(quantize_config):
-        if quantize_config is not None:
-            q_format = quantize_config.get("format")
-            q_precision = quantize_config.get("precision")
-
-            if q_format not in SUPPORTED_QUANTIZATIONS:
-                raise ValueError(
-                    f"Invalid quantization format: {q_format}. "
-                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
-                )
-
-            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f"Invalid precision: {q_precision} for format: {q_format}. "
-                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
-                )
-
-        return quantize_config
-
     @classmethod
-    def set_quantize_env(cls, quantize_config):
-        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
-        quantize_config = cls.validate_quantization_config(quantize_config)
-        if quantize_config is not None:
-            q_precision = quantize_config.get("precision")
-            quant_bits = q_precision.split("w")[1].split("a")[0]
-            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
-            return RBLN_QUANT_BITS_ENV
-        return None
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}

-    @classmethod
-    def reset_quantize_env(cls, env_var_name):
-        if env_var_name is not None and env_var_name in os.environ:
-            del os.environ[env_var_name]
+        # If the model wrapper supports rbln-custom-flash-attention
+        if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
+            wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")

-    @classmethod
-    def manage_quantize_env(cls, func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            quantize_config = kwargs.get("quantize_config")
-            quantize_env_var = cls.set_quantize_env(quantize_config)
-            try:
-                return func(*args, **kwargs)
-            finally:
-                cls.reset_quantize_env(quantize_env_var)
-
-        return wrapper
+        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

     @classmethod
     @torch.inference_mode()
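Note: `wrap_model_if_needed` now feature-detects the wrapper instead of requiring every subclass to override it: the flash-attention knob is forwarded only when the wrapper's `__init__` declares it. The idiom in isolation (the wrapper classes here are hypothetical):

    import inspect

    class BasicWrapper:
        def __init__(self, model, max_seq_len, use_rotary_emb=True): ...

    class FlashAttnWrapper:
        def __init__(self, model, max_seq_len, use_rotary_emb=True, kvcache_partition_len=None): ...

    def build(wrapper_cls, model, cfg):
        kwargs = {"max_seq_len": cfg["max_seq_len"]}
        # Pass the partition length only to wrappers that accept it.
        if "kvcache_partition_len" in inspect.signature(wrapper_cls.__init__).parameters:
            kwargs["kvcache_partition_len"] = cfg.get("kvcache_partition_len")
        return wrapper_cls(model, **kwargs)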
@@ -305,15 +264,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         @rbln_timer("JIT trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
+            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
             prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)
-
-            batch_index = 3
-            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)

+            wrapped_model.phase = "prefill"
             prefill_scripted_model = torch.jit.trace(
                 wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
             )
+            wrapped_model.phase = "decode"
             dec_scripted_model = torch.jit.trace(
                 wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
             )
@@ -336,6 +295,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
         cache_index_offset = 5
+
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
@@ -344,7 +304,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)

-        @cls.manage_quantize_env
+        @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
             # Remove quantize_config from kwargs
             kwargs.pop("quantize_config", None)
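Note: `@QuantizationManager.with_quantization_env` replaces the removed `manage_quantize_env`/`set_quantize_env`/`reset_quantize_env` trio. Judging from the deleted code, the decorator's contract is roughly this set-call-restore wrapper (a sketch of the expected behavior, not the actual `QuantizationManager` implementation):

    import functools
    import os

    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

    def with_quantization_env(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            quantize_config = kwargs.get("quantize_config")
            if quantize_config is not None:
                # Precision "w4a16" encodes 4-bit weights; export the weight bits.
                bits = quantize_config["precision"].split("w")[1].split("a")[0]
                os.environ[RBLN_QUANT_BITS_ENV] = bits
            try:
                return func(*args, **kwargs)
            finally:
                # Always restore the environment, even if compilation fails.
                os.environ.pop(RBLN_QUANT_BITS_ENV, None)
        return wrapper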
@@ -374,10 +334,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-
-        rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))

        prefill_chunk_size = 128
        if rbln_max_seq_len is None:
@@ -552,8 +510,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
-        # from llava_next forward args
-        batch_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll
@@ -575,7 +531,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                     input_ids=input_tensor if inputs_embeds is None else None,
                     inputs_embeds=input_tensor if inputs_embeds is not None else None,
                     cache_position=cache_position,
-                    batch_idx=b_idx if batch_idx is None else batch_idx,  # Llava-next prefill
+                    batch_idx=b_idx,
                 )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
@@ -676,11 +632,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+        if input_tensors is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")

         batch_size = input_tensors.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

         logits, _ = self.decoder(
@@ -693,31 +662,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         )

         return logits
-
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # decoder
-        else:
-            logits = self._forward_decoder(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-            )
-
-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-        )
optimum/rbln/transformers/models/dpt/modeling_dpt.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 from transformers import AutoModelForDepthEstimation
 from transformers.modeling_outputs import DepthEstimatorOutput

-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig

optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -20,53 +20,78 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+from typing import TYPE_CHECKING

+import torch.nn as nn

-from ...models.decoderonly import (
+from ....utils import logging
+from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyDecoderLayer,
+    DecoderOnlyFlashAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
 )

+if TYPE_CHECKING:
+    from transformers import PreTrainedModel as ExaoneForCausalLM
+
+logger = logging.get_logger(__name__)
+
+
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""

-    def __init__(self, model, max_seq_len):
-        super(DecoderOnlyWrapper, self).__init__()
-        self.config = model.config
-        self.model = self.convert_attribute_name(model.transformer)
-        self.lm_head = model.lm_head
-        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
-        self.max_position_embeddings = (
-            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
-        )
-        self.max_seq_len = max_seq_len
-        self.rope_scaling = getattr(self.config, "rope_scaling", None)
-        self.rotary_emb = self._init_rope()
-
-    @staticmethod
-    def convert_attribute_name(model):
-        model.embed_tokens = model.wte
-        model.norm = model.ln_f
-        model.layers = model.h
-
-        for layer in model.layers:
-            layer.input_layernorm = layer.ln_1
-            layer.self_attn = layer.attn.attention
-            layer.post_attention_layernorm = layer.ln_2
-            layer.self_attn.o_proj = layer.self_attn.out_proj
-
-        return model
-
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": DecoderOnlyModel.forward,
-                "model": DecoderOnlyDecoderLayer.forward,
-                "decoder_layer": DecoderOnlyAttention.forward,
-            }
-        )
-        return forward_dict
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            if self.attn_impl == "eager":
+                new_self_attn = ExaoneAttention(layer.attn.attention)
+            elif self.attn_impl == "flash_attn":
+                new_self_attn = ExaoneFlashAttention(
+                    layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                )
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+            new_layer = ExaoneLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = ExaoneModel(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class ExaoneModel(DecoderOnlyModel):
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
+
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
+
+
+class ExaoneLayer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
+
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
+
+
+class ExaoneAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
+
+
+class ExaoneFlashAttention(DecoderOnlyFlashAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.out_proj
+        self.num_key_value_heads = self._original_mod.num_key_value_heads
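Note: this rewrite replaces in-place attribute aliasing (`model.embed_tokens = model.wte`, etc.) with thin adapter classes: the shared DecoderOnly code calls getters, and each port overrides only the getters whose HF names differ. The pattern in miniature (base/adapter names here are hypothetical):

    import torch.nn as nn

    class ModelAdapter:
        def __init__(self, original_mod: nn.Module):
            self._original_mod = original_mod

        def get_embedding(self) -> nn.Embedding:
            # Default: the standard HF attribute name.
            return self._original_mod.embed_tokens

    class ExaoneStyleAdapter(ModelAdapter):
        def get_embedding(self) -> nn.Embedding:
            # Exaone names its embedding `wte`; the wrapped module is never mutated.
            return self._original_mod.wte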
optimum/rbln/transformers/models/exaone/modeling_exaone.py
@@ -21,21 +21,15 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-import inspect
-import logging
-from typing import TYPE_CHECKING, Any, Callable

-from ....modeling_config import RBLNConfig
+from transformers import AutoModelForCausalLM
+
+from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .exaone_architecture import ExaoneForCausalLMWrapper
-from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


-logger = logging.getLogger(__name__)
-if TYPE_CHECKING:
-    from transformers import (
-        PreTrainedModel,
-    )
+logger = logging.get_logger(__name__)


 class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -52,25 +46,8 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):

     """

-    @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-        return ExaoneForCausalLMWrapper(model, rbln_max_seq_len).eval()
-
-    def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-Exaone.
-
-        Returns:
-            Any: Exaone's corresponding method
-        """
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(ExaoneForCausalLM, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
+    _decoder_wrapper_cls = ExaoneForCausalLMWrapper
+    _hf_class = AutoModelForCausalLM

     @classmethod
     def from_pretrained(cls, *args, **kwargs):
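Note: with `_decoder_wrapper_cls` and `_hf_class` acting as declarative hooks, a new decoder-only port shrinks to a wrapper subclass plus two class attributes. A hypothetical port, for illustration only (import paths follow this package's layout, but the model itself is made up):

    from transformers import AutoModelForCausalLM

    from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM
    from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper

    class MyModelWrapper(DecoderOnlyWrapper):
        # Override convert_to_rbln_causal_lm here if the HF module names
        # differ from the defaults, as ExaoneForCausalLMWrapper does above.
        pass

    class RBLNMyModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
        _decoder_wrapper_cls = MyModelWrapper
        _hf_class = AutoModelForCausalLM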