optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -20,27 +20,23 @@
  # are the intellectual property of Rebellions Inc. and may not be
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
- import functools
- import glob
  import inspect
- import os
  from dataclasses import dataclass
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
- import transformers
- from safetensors.torch import load_file
  from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_utils import no_init_weights
  from transformers.utils import ModelOutput

- from ....modeling_base import RBLNModel
+ from ....modeling import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
  from ....utils.timer_utils import rbln_timer
+ from ...utils.rbln_quantization import QuantizationManager
  from .decoderonly_architecture import DecoderOnlyWrapper


@@ -54,12 +50,6 @@ if TYPE_CHECKING:
          PretrainedConfig,
      )

- SUPPORTED_QUANTIZATIONS = {
-     "rbln": [
-         "w4a16",
-     ],
- }
-

  class RBLNRuntimeModel(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name", "embed_tokens"]
@@ -127,24 +117,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      main_input_name = "input_ids"
      auto_model_class = AutoModelForCausalLM
      _decoder_wrapper_cls = DecoderOnlyWrapper
-     _original_cls = None
-
-     @classmethod
-     @property
-     def original_cls(cls):
-         """
-         Lazily loads and caches the corresponding Hugging Face model class.
-         Removes 'RBLN' prefix from the class name to get the original class name
-         (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-         the transformers module.
-
-         Returns:
-             type: The original Hugging Face model class
-         """
-         if cls._original_cls is None:
-             hf_original_cls_name = cls.__name__[4:]
-             cls._original_cls = getattr(transformers, hf_original_cls_name)
-         return cls._original_cls
+     _use_rotary_emb = True

      def __post_init__(self, **kwargs):
          self.batch_size = self.rbln_config.model_cfg["batch_size"]
@@ -203,6 +176,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      def get_quantized_model(
          cls,
          model_id: str,
+         config: Optional[PretrainedConfig] = None,
          use_auth_token: Optional[Union[bool, str]] = None,
          revision: Optional[str] = None,
          force_download: bool = False,
@@ -212,57 +186,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          trust_remote_code: bool = False,
          **kwargs,
      ):
-         from ...utils.rbln_quantization import update_layers_to_quantized
+         from ...utils.rbln_quantization import prepare_model_for_quantization

          kwargs = cls.update_kwargs(kwargs)

-         config = AutoConfig.from_pretrained(
-             model_id,
-             use_auth_token=use_auth_token,
-             revision=revision,
-             force_download=force_download,
-             cache_dir=cache_dir,
-             trust_remote_code=trust_remote_code,
-             **kwargs,
-         )
+         if config is None:
+             config = AutoConfig.from_pretrained(
+                 model_id,
+                 use_auth_token=use_auth_token,
+                 revision=revision,
+                 force_download=force_download,
+                 cache_dir=cache_dir,
+                 trust_remote_code=trust_remote_code,
+                 **kwargs,
+             )

          with no_init_weights():
              model = AutoModelForCausalLM.from_config(config)

-         update_layers_to_quantized(model)
-
-         n_layer = kwargs.get("num_hidden_layers", None)
-         cls._load_weights_directly_to_model(model, model_id, n_layer)
+         prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))

          return model

-     def _load_weights_directly_to_model(model, model_id, n_layer=None):
-         """
-         Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
-         """
-
-         model_params = dict(model.named_parameters(recurse=True))
-         model_buffers = dict(model.named_buffers(recurse=True))
-         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
-
-         target_layers = list(range(n_layer)) if n_layer is not None else None
-
-         for safetensor_file in safetensor_files:
-             file_data = load_file(safetensor_file)
-             for key, value in file_data.items():
-                 if target_layers is not None:
-                     parts = key.split(".")
-
-                     if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                         continue
-
-                 if key in model_params:
-                     model_params[key].data.copy_(value)
-                 elif key in model_buffers:
-                     model_buffers[key].data.copy_(value)
-
-         return 0
-
      def __getattr__(self, __name: str) -> Any:
          """
          Special method to delegate attribute access to the original Huggingface LM class.
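The removed `_load_weights_directly_to_model` helper (and the separate `update_layers_to_quantized` call) are consolidated behind `prepare_model_for_quantization` in `rbln_quantization.py`. A minimal sketch of what such a helper plausibly does, reconstructed from the removed code above; the real function may differ and presumably also swaps in the quantized linear layers:

# Hypothetical sketch; the real prepare_model_for_quantization lives in
# optimum/rbln/transformers/utils/rbln_quantization.py and may differ.
import glob
from typing import Optional

import torch.nn as nn
from safetensors.torch import load_file


def prepare_model_for_quantization(model: nn.Module, model_id: str, n_layer: Optional[int] = None) -> None:
    # Copy safetensors tensors directly into parameters/buffers, optionally
    # skipping decoder layers beyond n_layer (mirrors the removed helper above).
    params = dict(model.named_parameters(recurse=True))
    buffers = dict(model.named_buffers(recurse=True))
    target_layers = set(range(n_layer)) if n_layer is not None else None

    for shard in glob.glob(f"{model_id}/*.safetensors"):
        for key, value in load_file(shard).items():
            parts = key.split(".")
            if (
                target_layers is not None
                and len(parts) > 2
                and parts[2].isdigit()
                and int(parts[2]) not in target_layers
            ):
                continue
            if key in params:
                params[key].data.copy_(value)
            elif key in buffers:
                buffers[key].data.copy_(value)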
@@ -278,7 +223,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          def redirect(func):
              return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

-         val = getattr(self.original_cls, __name)
+         val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
          if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
              return redirect(val)
          return val
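The delegation itself is unchanged: attribute lookups that miss on the RBLN class fall through `__getattr__` to the resolved Hugging Face class (now `hf_class`, with `PreTrainedModel` as a fallback), and plain functions that expect `self` are rebound to the RBLN instance. A self-contained sketch of that pattern, with made-up class names:

# Minimal standalone sketch of the delegation pattern above (names here are
# illustrative, not part of optimum-rbln).
import inspect
from typing import Any, Callable


class Template:
    def greet(self):
        return f"hello from {type(self).__name__}"


class Delegator:
    hf_class = Template  # stand-in for the resolved Hugging Face class

    def __getattr__(self, name: str) -> Any:
        def redirect(func):
            # Rebind unbound functions so they receive this instance as `self`.
            return lambda *args, **kwargs: func(self, *args, **kwargs)

        val = getattr(self.hf_class, name)
        if isinstance(val, Callable) and "self" in inspect.signature(val).parameters:
            return redirect(val)
        return val


print(Delegator().greet())  # -> "hello from Delegator"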
@@ -295,54 +240,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

          return model

-     def validate_quantization_config(quantize_config):
-         if quantize_config is not None:
-             q_format = quantize_config.get("format")
-             q_precision = quantize_config.get("precision")
-
-             if q_format not in SUPPORTED_QUANTIZATIONS:
-                 raise ValueError(
-                     f"Invalid quantization format: {q_format}. "
-                     f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
-                 )
-
-             if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                 raise ValueError(
-                     f"Invalid precision: {q_precision} for format: {q_format}. "
-                     f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
-                 )
-
-         return quantize_config
-
-     @classmethod
-     def set_quantize_env(cls, quantize_config):
-         RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
-         quantize_config = cls.validate_quantization_config(quantize_config)
-         if quantize_config is not None:
-             q_precision = quantize_config.get("precision")
-             quant_bits = q_precision.split("w")[1].split("a")[0]
-             os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
-             return RBLN_QUANT_BITS_ENV
-         return None
-
-     @classmethod
-     def reset_quantize_env(cls, env_var_name):
-         if env_var_name is not None and env_var_name in os.environ:
-             del os.environ[env_var_name]
-
-     @classmethod
-     def manage_quantize_env(cls, func):
-         @functools.wraps(func)
-         def wrapper(*args, **kwargs):
-             quantize_config = kwargs.get("quantize_config")
-             quantize_env_var = cls.set_quantize_env(quantize_config)
-             try:
-                 return func(*args, **kwargs)
-             finally:
-                 cls.reset_quantize_env(quantize_env_var)
-
-         return wrapper
-
      @classmethod
      def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
          wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
@@ -351,6 +248,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          if "kvcache_partition_len" in inspect.signature(cls._decoder_wrapper_cls.__init__).parameters:
              wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")

+         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+
          return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

      @classmethod
@@ -369,9 +268,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
          dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)

+         wrapped_model.phase = "prefill"
          prefill_scripted_model = torch.jit.trace(
              wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
          )
+         wrapped_model.phase = "decode"
          dec_scripted_model = torch.jit.trace(
              wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
          )
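The wrapper now exposes a single module whose `phase` attribute selects between the prefill and decode code paths, and each `torch.jit.trace` call captures exactly one of them. A toy illustration of that pattern (the wrapper below is invented; only the set-phase-then-trace flow mirrors the code above):

# Toy illustration of phase-switched tracing; shapes and the wrapper itself are made up.
import torch
import torch.nn as nn


class ToyWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.phase = "prefill"

    def forward(self, x):
        # The branch on self.phase is resolved at trace time, so each trace
        # captures exactly one phase, like the prefill/decoder graphs above.
        if self.phase == "prefill":
            return self.proj(x)            # whole prompt chunk at once
        return self.proj(x[:, -1:, :])     # single decode step


model = ToyWrapper().eval()
model.phase = "prefill"
prefill_graph = torch.jit.trace(model, torch.randn(1, 4, 8), check_trace=False)
model.phase = "decode"
decode_graph = torch.jit.trace(model, torch.randn(1, 4, 8), check_trace=False)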
@@ -394,6 +295,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          prefill_ir, dec_ir = scripted_model_to_ir()
          # Caching prefill_decoder/decoder I/O
          cache_index_offset = 5
+
          connections = [
              (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
              for i in range(model.config.num_hidden_layers * 2)
@@ -402,7 +304,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          # Extract quantize_config from rbln_config
          quantize_config = rbln_config.model_cfg.get("quantization", None)

-         @cls.manage_quantize_env
+         @QuantizationManager.with_quantization_env
          def compile_model(*args, **kwargs):
              # Remove quantize_config from kwargs
              kwargs.pop("quantize_config", None)
@@ -432,10 +334,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      ) -> RBLNConfig:
          rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
          rbln_batch_size = rbln_kwargs.get("batch_size", None)
-         rbln_quantization = rbln_kwargs.get("quantization", None)
          rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-
-         rbln_quantization = cls.validate_quantization_config(rbln_quantization)
+         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))

          prefill_chunk_size = 128
          if rbln_max_seq_len is None:
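Validation and the `RBLN_QUANT_BITS` environment handling that were removed from this class now live on `QuantizationManager`. A rough sketch of what that manager might look like, assembled from the removed `SUPPORTED_QUANTIZATIONS`, `validate_quantization_config`, and `set/reset_quantize_env` code; the actual implementation in `rbln_quantization.py` may differ:

# Hypothetical sketch of QuantizationManager; it simply centralizes the
# validation and RBLN_QUANT_BITS handling that was removed above.
import functools
import os


class QuantizationManager:
    SUPPORTED_QUANTIZATIONS = {"rbln": ["w4a16"]}
    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

    @classmethod
    def validate_quantization_config(cls, quantize_config):
        if quantize_config is not None:
            q_format = quantize_config.get("format")
            q_precision = quantize_config.get("precision")
            if q_format not in cls.SUPPORTED_QUANTIZATIONS:
                raise ValueError(f"Invalid quantization format: {q_format}")
            if q_precision not in cls.SUPPORTED_QUANTIZATIONS[q_format]:
                raise ValueError(f"Invalid precision: {q_precision} for format: {q_format}")
        return quantize_config

    @classmethod
    def with_quantization_env(cls, func):
        # Set RBLN_QUANT_BITS (e.g. "4" for "w4a16") only while `func` runs.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            quantize_config = cls.validate_quantization_config(kwargs.get("quantize_config"))
            if quantize_config is None:
                return func(*args, **kwargs)
            quant_bits = quantize_config["precision"].split("w")[1].split("a")[0]
            os.environ[cls.RBLN_QUANT_BITS_ENV] = quant_bits
            try:
                return func(*args, **kwargs)
            finally:
                os.environ.pop(cls.RBLN_QUANT_BITS_ENV, None)

        return wrapper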
@@ -610,8 +510,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          cache_position: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          generate_idx: Optional[torch.Tensor] = None,
-         # from llava_next forward args
-         batch_idx: Optional[int] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
          # prefll
@@ -633,7 +531,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                    input_ids=input_tensor if inputs_embeds is None else None,
                    inputs_embeds=input_tensor if inputs_embeds is not None else None,
                    cache_position=cache_position,
-                   batch_idx=b_idx if batch_idx is None else batch_idx,  # Llava-next prefill
+                   batch_idx=b_idx,
                )
                logits.append(logit)
            logits = torch.cat(logits, dim=0)
@@ -734,11 +632,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          cache_position: torch.Tensor = None,
      ) -> torch.FloatTensor:
          input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+         if input_tensors is None:
+             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")

          batch_size = input_tensors.shape[0]
+         if batch_size != self.batch_size:
+             raise RuntimeError(
+                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+             )
+
+         if batch_size != cache_position.shape[0]:
+             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

          for b_idx in range(batch_size):
              decoding_step = cache_position[b_idx].item()
+             if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                 raise ValueError(
+                     f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                 )
              self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

          logits, _ = self.decoder(
@@ -751,31 +662,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          )

          return logits
-
-     def vllm_forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         inputs_embeds: torch.Tensor = None,
-         cache_position: torch.Tensor = None,
-         batch_idx: Optional[int] = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor]:
-         # prefll
-         if cache_position.shape[-1] > 1:
-             logits = self._forward_prefill(
-                 input_ids=input_ids,
-                 inputs_embeds=inputs_embeds,
-                 cache_position=cache_position,
-                 batch_idx=batch_idx,
-             )
-         # decoder
-         else:
-             logits = self._forward_decoder(
-                 input_ids=input_ids,
-                 inputs_embeds=inputs_embeds,
-                 cache_position=cache_position,
-             )
-
-         return RBLNDecoderOnlyOutput(
-             logits=logits,
-         )
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
  from transformers import AutoModelForDepthEstimation
  from transformers.modeling_outputs import DepthEstimatorOutput

- from ....modeling_base import RBLNModel
+ from ....modeling import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig


optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -20,62 +20,78 @@
  # are the intellectual property of Rebellions Inc. and may not be
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
- import torch
+ from typing import TYPE_CHECKING
+
+ import torch.nn as nn

  from ....utils import logging
- from ...models.decoderonly import (
+ from ...models.decoderonly.decoderonly_architecture import (
      DecoderOnlyAttention,
-     DecoderOnlyDecoderLayer,
+     DecoderOnlyFlashAttention,
+     DecoderOnlyForCausalLM,
+     DecoderOnlyLayer,
      DecoderOnlyModel,
      DecoderOnlyWrapper,
-     RotaryEmbedding,
  )


+ if TYPE_CHECKING:
+     from transformers import PreTrainedModel as ExaoneForCausalLM
+
  logger = logging.get_logger(__name__)


  class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
      """A wrapper class for the Exaone model with a language modeling head."""

-     def __init__(self, model, max_seq_len, kvcache_partition_len=None):
-         super(DecoderOnlyWrapper, self).__init__()
-         self.config = model.config
-         self.model = self.convert_attribute_name(model.transformer)
-         self.lm_head = model.lm_head
-         self.rotary_emb = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-         if kvcache_partition_len is not None:
-             # WORKAROUND : for passing partition length as a value to the rbln compiler.
-             # What is actually used is the shape of this tensor.
-             self.kvcache_partition_size = torch.zeros(kvcache_partition_len, dtype=torch.int32)
-             self.attn_implementation = "flash_attn_rbln"
-             logger.info(f"Using rbln-flash-attention. (partition length : {kvcache_partition_len})")
-         else:
-             self.kvcache_partition_size = None
-             self.attn_implementation = "eager"
-
-     @staticmethod
-     def convert_attribute_name(model):
-         model.embed_tokens = model.wte
-         model.norm = model.ln_f
-         model.layers = model.h
-
-         for layer in model.layers:
-             layer.input_layernorm = layer.ln_1
-             layer.self_attn = layer.attn.attention
-             layer.post_attention_layernorm = layer.ln_2
-             layer.self_attn.o_proj = layer.self_attn.out_proj
-
-         return model
-
-     def get_forward_dict(self):
-         forward_dict = {}
-         forward_dict.update(
-             {
-                 "wrapper": DecoderOnlyModel.forward,
-                 "model": DecoderOnlyDecoderLayer.forward,
-                 "decoder_layer": DecoderOnlyAttention.forward,
-             }
-         )
-         return forward_dict
+     def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+         new_layers = []
+         for layer in causal_lm.transformer.h:
+             if self.attn_impl == "eager":
+                 new_self_attn = ExaoneAttention(layer.attn.attention)
+             elif self.attn_impl == "flash_attn":
+                 new_self_attn = ExaoneFlashAttention(
+                     layer.attn.attention, kvcache_partition_len=self.kvcache_partition_len
+                 )
+             else:
+                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+             new_layer = ExaoneLayer(layer, new_self_attn)
+             new_layers.append(new_layer)
+         new_model = ExaoneModel(causal_lm.transformer, new_layers)
+         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+         return new_causal_lm
+
+
+ class ExaoneModel(DecoderOnlyModel):
+     def get_embedding(self) -> nn.Embedding:
+         return self._original_mod.wte
+
+     def get_last_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.ln_f
+
+
+ class ExaoneLayer(DecoderOnlyLayer):
+     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.ln_1
+
+     def get_post_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.ln_2
+
+
+ class ExaoneAttention(DecoderOnlyAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.o_proj = self._original_mod.out_proj
+         self.num_key_value_heads = self._original_mod.num_key_value_heads
+
+
+ class ExaoneFlashAttention(DecoderOnlyFlashAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.o_proj = self._original_mod.out_proj
+         self.num_key_value_heads = self._original_mod.num_key_value_heads
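The Exaone wrapper no longer monkey-patches attribute names onto the Hugging Face modules (`convert_attribute_name`); instead each architecture subclasses the shared building blocks and overrides small accessors. A stripped-down sketch of that accessor pattern, with invented `Tiny*` names (the real `DecoderOnlyModel`/`DecoderOnlyLayer` carry the full attention and KV-cache machinery):

# Sketch only; illustrates the accessor-override idea, not the real classes.
import torch.nn as nn


class TinyDecoderOnlyModel(nn.Module):
    """Wraps an original HF module and reads its submodules through overridable accessors."""

    def __init__(self, original_mod: nn.Module, layers):
        super().__init__()
        self._original_mod = original_mod
        self.layers = nn.ModuleList(layers)

    # Default accessors assume Llama-style attribute names ...
    def get_embedding(self) -> nn.Embedding:
        return self._original_mod.embed_tokens

    def get_last_layernorm(self) -> nn.Module:
        return self._original_mod.norm

    def forward(self, input_ids):
        hidden_states = self.get_embedding()(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return self.get_last_layernorm()(hidden_states)


# ... so a GPT-style checkpoint (wte / ln_f) only overrides the accessors,
# exactly as ExaoneModel does above.
class TinyExaoneModel(TinyDecoderOnlyModel):
    def get_embedding(self) -> nn.Embedding:
        return self._original_mod.wte

    def get_last_layernorm(self) -> nn.Module:
        return self._original_mod.ln_f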
optimum/rbln/transformers/models/exaone/modeling_exaone.py

@@ -21,10 +21,12 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

+
+ from transformers import AutoModelForCausalLM
+
  from ....utils import logging
  from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .exaone_architecture import ExaoneForCausalLMWrapper
- from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM


  logger = logging.get_logger(__name__)
@@ -45,7 +47,7 @@ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """

      _decoder_wrapper_cls = ExaoneForCausalLMWrapper
-     _original_cls = ExaoneForCausalLM
+     _hf_class = AutoModelForCausalLM

      @classmethod
      def from_pretrained(cls, *args, **kwargs):
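With the vendored `hf_hub_cached` copy gone, `RBLNExaoneForCausalLM` resolves the original checkpoint through `AutoModelForCausalLM`. A hypothetical export call; the checkpoint id and the `rbln_*` keyword names follow the usual optimum-rbln conventions and are not taken from this diff:

# Hypothetical usage sketch; all arguments shown are illustrative.
from optimum.rbln import RBLNExaoneForCausalLM

model = RBLNExaoneForCausalLM.from_pretrained(
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",  # resolved via AutoModelForCausalLM
    export=True,                 # compile the Hugging Face checkpoint for RBLN NPUs
    trust_remote_code=True,      # Exaone ships custom modeling code on the Hub
    rbln_max_seq_len=4096,
    rbln_batch_size=1,
)
model.save_pretrained("exaone-rbln")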
optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -21,113 +21,42 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- from typing import Dict, List, Optional, Tuple, Union
-
- import torch
- from transformers.modeling_outputs import (
-     BaseModelOutputWithPast,
- )
-
- from ...models.decoderonly import (
-     DecoderOnlyDecoderLayer,
+ from typing import TYPE_CHECKING
+
+ from ...models.decoderonly.decoderonly_architecture import (
+     DecoderOnlyAttention,
+     DecoderOnlyFlashAttention,
+     DecoderOnlyForCausalLM,
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
      DecoderOnlyWrapper,
-     slice_and_unsqueeze_cos_sin,
  )
- from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES
-
-
- class GemmaWrapper(DecoderOnlyWrapper):
-     def get_forward_dict(self):
-         forward_dict = {}
-         forward_dict.update(
-             {
-                 "wrapper": GemmaModel.forward,
-                 "model": DecoderOnlyDecoderLayer.forward,
-                 "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
-             }
-         )
-         return forward_dict
-
-
- class GemmaModel:
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         batch_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = True,
-         output_attentions: Optional[bool] = False,
-         output_hidden_states: Optional[bool] = False,
-         cache_pos_for_partitions: Optional[torch.Tensor] = None,
-         kvcache_partition_size: Optional[torch.Tensor] = None,
-         forward_dict: Optional[Dict[str, classmethod]] = None,
-         rotary_pos_emb=None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         # retrieve input_ids and inputs_embeds
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError(
-                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-             )
-
-         # embed positions
-         inputs_embeds = self.embed_tokens(input_ids)
-         hidden_states = inputs_embeds

-         ##### GEMMA change from llama#####
-         hidden_states = hidden_states * (self.config.hidden_size**0.5)

-         attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
+ if TYPE_CHECKING:
+     from transformers import GemmaForCausalLM

-         # get cos,sin vector
-         cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-         cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)

-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-
-         for layer_idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-             layer_outputs = forward_dict["model"](
-                 decoder_layer,
-                 hidden_states,
-                 layer_idx,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 past_key_value=past_key_values,
-                 output_attentions=output_attentions,
-                 use_cache=use_cache,
-                 batch_ids=batch_ids,
-                 cos=cos,
-                 sin=sin,
-                 cache_pos_for_partitions=cache_pos_for_partitions,
-                 kvcache_partition_size=kvcache_partition_size,
-                 forward_dict=forward_dict,
-             )
-
-             hidden_states = layer_outputs[0]
-
-             updated_cache = layer_outputs[2 if output_attentions else 1]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         hidden_states = self.norm(hidden_states)
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-         next_cache = updated_cache.to_legacy_cache()
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
+ class GemmaWrapper(DecoderOnlyWrapper):
+     def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+         new_layers = []
+         for layer in causal_lm.model.layers:
+             if self.attn_impl == "eager":
+                 new_self_attn = DecoderOnlyAttention(layer.self_attn)
+             elif self.attn_impl == "flash_attn":
+                 new_self_attn = DecoderOnlyFlashAttention(
+                     layer.self_attn, kvcache_partition_len=self.kvcache_partition_len
+                 )
+             else:
+                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+             new_layer = DecoderOnlyLayer(layer, new_self_attn)
+             new_layers.append(new_layer)
+         new_model = GemmaModel(causal_lm.model, new_layers)
+         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+         return new_causal_lm
+
+
+ class GemmaModel(DecoderOnlyModel):
+     @property
+     def hidden_multiplier(self):
+         return self._original_mod.config.hidden_size**0.5
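The only Gemma-specific behavior that survives is the embedding scale: the removed `GemmaModel.forward` multiplied the embeddings by `hidden_size ** 0.5`, and the new `GemmaModel` keeps just that factor as `hidden_multiplier`, presumably consumed by the shared `DecoderOnlyModel` forward. A minimal sketch of the assumed hook:

# Sketch only: assumes the shared decoder-only forward multiplies the embedding
# output by `hidden_multiplier` (1.0 by default, sqrt(hidden_size) for Gemma).
import torch
import torch.nn as nn


class SketchDecoderOnlyModel(nn.Module):
    hidden_multiplier = 1.0  # overridden per architecture

    def __init__(self, embed_tokens: nn.Embedding):
        super().__init__()
        self.embed_tokens = embed_tokens

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids) * self.hidden_multiplier
        return hidden_states  # decoder layers / final norm omitted


class SketchGemmaModel(SketchDecoderOnlyModel):
    def __init__(self, embed_tokens: nn.Embedding, hidden_size: int):
        super().__init__(embed_tokens)
        self._hidden_size = hidden_size

    @property
    def hidden_multiplier(self):
        return self._hidden_size**0.5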