optimum-rbln 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. optimum/rbln/__init__.py +3 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  4. optimum/rbln/diffusers/models/controlnet.py +4 -3
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
  7. optimum/rbln/modeling_alias.py +5 -1
  8. optimum/rbln/modeling_base.py +53 -19
  9. optimum/rbln/transformers/__init__.py +3 -1
  10. optimum/rbln/transformers/models/__init__.py +1 -0
  11. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +4 -3
  13. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +137 -22
  14. optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
  16. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
  17. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
  18. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
  19. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  20. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  21. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  22. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  23. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +8 -2
  24. optimum/rbln/transformers/utils/__init__.py +0 -0
  25. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  26. optimum/rbln/utils/import_utils.py +1 -4
  27. optimum/rbln/utils/runtime_utils.py +2 -1
  28. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +10 -3
  29. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +31 -26
  30. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  31. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -32,6 +32,7 @@ _import_structure = {
     "modeling_alias": [
         "RBLNASTForAudioClassification",
         "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
         "RBLNT5ForConditionalGeneration",
         "RBLNBartForConditionalGeneration",
@@ -61,6 +62,7 @@ _import_structure = {
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
         "RBLNXLMRobertaModel",
     ],
@@ -126,6 +128,7 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = '0.1.8'
+__version__ = '0.1.9'
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -26,7 +26,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import rebel
-import torch
+import torch  # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -38,12 +38,12 @@ from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRunt
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     import torch
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -34,12 +34,13 @@ from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
+logger = logging.getLogger(__name__)
+
+
 class _ControlNetModel(torch.nn.Module):
     def __init__(self, controlnet: "ControlNetModel"):
         super().__init__()
@@ -138,7 +139,7 @@ class RBLNControlNetModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -35,11 +35,11 @@ from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 class _UNet_SD(torch.nn.Module):
     def __init__(self, unet: "UNet2DConditionModel"):
@@ -172,7 +172,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         else:
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -37,11 +37,11 @@ from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNMultiControlNetModel(RBLNModel):
     def __init__(
@@ -79,7 +79,6 @@ class RBLNMultiControlNetModel(RBLNModel):
         model_id: Union[str, Path],
         **kwargs,
     ) -> RBLNModel:
-
         idx = 0
         controlnets = []
         model_path_to_load = model_id
optimum/rbln/modeling_alias.py CHANGED
@@ -36,7 +36,11 @@ class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
 
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
-    pass
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+
+class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
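
The question-answering aliases now pin their model-specific default inputs as a class attribute, so DistilBERT (which has no token_type_ids) compiles without a manual override. A minimal usage sketch; the checkpoint id, the export=True flag, and the sequence length are illustrative assumptions following the usual optimum convention, not part of this diff:

    from optimum.rbln import RBLNDistilBertForQuestionAnswering

    # Only input_ids and attention_mask are traced for DistilBERT.
    model = RBLNDistilBertForQuestionAnswering.from_pretrained(
        "distilbert-base-cased-distilled-squad",  # illustrative checkpoint
        export=True,
        rbln_max_seq_len=384,
    )

Passing rbln_model_input_names explicitly at export time still overrides the class default (see the RBLNModelForQuestionAnswering change in modeling_base.py below).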
optimum/rbln/modeling_base.py CHANGED
@@ -51,10 +51,15 @@ from .utils.runtime_utils import UnavailableRuntime
 from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PreTrainedModel,
+    )
+
+logger = logging.getLogger(__name__)
 
 
 class RBLNBaseModel(OptimizedModel, ABC):
@@ -156,13 +161,23 @@ class RBLNBaseModel(OptimizedModel, ABC):
             Directory where to save the model file.
         """
         real_save_dir = self.model_save_dir / self.subfolder
+        save_directory_path = Path(save_directory)
         if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            if save_directory_path.absolute() == real_save_dir.absolute():
+                raise FileExistsError(
+                    f"Cannot save model to '{save_directory}'. "
+                    f"This directory already exists and contains the model files."
+                )
             shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
             self.config.save_pretrained(save_directory)
             if self.generation_config is not None:
                 self.generation_config.save_pretrained(save_directory)
         else:
-            raise FileNotFoundError(f"Saving compiled model failed.({real_save_dir}).")
+            raise FileNotFoundError(
+                f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
+                f"Cannot save to the specified destination '{save_directory}'. "
+                f"Please ensure the model directory exists and you have the necessary permissions to access it."
+            )
 
     @classmethod
     def _from_pretrained(
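
save_pretrained now fails fast in both directions: saving onto the model's own directory raises FileExistsError instead of copying a tree onto itself, and a missing source directory produces an actionable FileNotFoundError. A hedged sketch of the resulting behavior (paths are illustrative, assuming the model was loaded from a local compiled directory):

    model = RBLNResNetForImageClassification.from_pretrained("resnet50-compiled")

    model.save_pretrained("resnet50-copy")      # OK: compiled artifacts and config are copied
    model.save_pretrained("resnet50-compiled")  # FileExistsError: destination is the model directory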
@@ -196,7 +211,12 @@ class RBLNBaseModel(OptimizedModel, ABC):
                 token = HfFolder().get_token()
             else:
                 token = use_auth_token
-            repo_files = list(map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)))
+            repo_files = list(
+                map(
+                    Path,
+                    HfApi().list_repo_files(model_id, revision=revision, token=token),
+                )
+            )
 
             pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
             rbln_files = [p for p in repo_files if p.match(pattern)]
@@ -287,7 +307,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
             preprocessors,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
+            rbln_compiled_models=(None if rbln_optimize_host_memory else rbln_compiled_models),
             **kwargs,
         )
 
@@ -377,7 +397,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
         return self.forward(*args, **kwargs)
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model
 
@@ -400,7 +420,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
     @classmethod
     @abstractmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         # compiled_models -> runtimes
         pass
@@ -497,7 +519,7 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        model = cls.wrap_model_if_needed(model)
+        model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_runtime_configs = list(rbln_config.values())
         if len(rbln_runtime_configs) != 1:
             raise ValueError
@@ -598,7 +620,9 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
@@ -618,8 +642,8 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
         rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
     ) -> RBLNConfig:
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -629,15 +653,15 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")
 
-        if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
+        if rbln_model_input_names is not None:
+            cls.rbln_model_input_names = rbln_model_input_names
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
+            for model_input_name in cls.rbln_model_input_names
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
@@ -674,7 +698,13 @@ class RBLNModelForImageClassification(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size, rbln_image_size], "float32")]
+        input_info = [
+            (
+                "pixel_values",
+                [rbln_batch_size, 3, rbln_image_size, rbln_image_size],
+                "float32",
+            )
+        ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
         rbln_runtime_config.batch_size = rbln_batch_size
@@ -739,7 +769,11 @@ class RBLNModelForAudioClassification(RBLNModel):
         meta["rbln_num_mel_bins"] = rbln_num_mel_bins
 
         model_input_info = [
-            ("input_values", [rbln_batch_size, rbln_max_length, rbln_num_mel_bins], "float32"),
+            (
+                "input_values",
+                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
+                "float32",
+            ),
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
@@ -777,7 +811,6 @@ class RBLNModelForSequenceClassification(RBLNModel):
         rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
     ) -> RBLNConfig:
-
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -812,6 +845,7 @@ class RBLNModelForSequenceClassification(RBLNModel):
 
         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
 
+
 class RBLNModelForMaskedLM(RBLNModel):
     model_type = "rbln_model"
     auto_model_class = AutoModelForMaskedLM
optimum/rbln/transformers/__init__.py CHANGED
@@ -39,7 +39,8 @@ _import_structure = {
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
-        "RBLNXLMRobertaModel"
+        "RBLNMistralForCausalLM",
+        "RBLNXLMRobertaModel",
     ],
 }
 
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
optimum/rbln/transformers/models/__init__.py CHANGED
@@ -27,6 +27,7 @@ from .gemma import RBLNGemmaForCausalLM
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
 from .midm import RBLNMidmLMHeadModel
+from .mistral import RBLNMistralForCausalLM
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
 from .xlm_roberta import RBLNXLMRobertaModel
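
Mistral support is the headline addition of 0.1.9 (the new mistral/ package, files 19-21 in the list above). Usage mirrors the other decoder-only models; a sketch in which the checkpoint id and sizes are illustrative assumptions:

    from optimum.rbln import RBLNMistralForCausalLM

    model = RBLNMistralForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",  # illustrative checkpoint
        export=True,            # compile from the PyTorch weights
        rbln_max_seq_len=4096,
        rbln_batch_size=1,
    )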
optimum/rbln/transformers/models/clip/modeling_clip.py CHANGED
@@ -70,7 +70,7 @@ class RBLNCLIPTextModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py CHANGED
@@ -49,18 +49,19 @@ class DecoderOnlyWrapper(torch.nn.Module):
             self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
         )
         self.max_seq_len = max_seq_len
+        self.rope_scaling = getattr(self.config, "rope_scaling", None)
         self.rotary_emb = self._init_rope()
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        if self.rope_scaling is None:
             rotary_emb = RotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.config.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
+            scaling_type = self.rope_scaling["type"]
+            scaling_factor = self.rope_scaling["factor"]
             if scaling_type == "linear":
                 rotary_emb = LinearScalingRotaryEmbedding(
                     self.head_dim,
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py CHANGED
@@ -20,18 +20,22 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import glob
 import logging
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import rebel  # noqa: F401
 import torch  # noqa: F401
-from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import no_init_weights
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...utils.rbln_quantization import replace_quantized_linear_layers
 
 
 logger = logging.getLogger(__name__)
@@ -44,6 +48,12 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )
 
+SUPPORTED_QUANTIZATIONS = {
+    "rbln": [
+        "w4a16",
+    ],
+}
+
 
 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
@@ -78,26 +88,98 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
 
     @classmethod
-    @abstractmethod
-    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
-        pass
+    def get_quantized_model(
+        cls,
+        model_id: str,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        kwargs = cls.update_kwargs(kwargs)
+
+        config = AutoConfig.from_pretrained(
+            model_id,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config)
+        replace_quantized_linear_layers(model)
+
+        state_dict = {}
+        for safetensor_file in glob.glob(f"{model_id}/*.safetensors"):
+            partial_state_dict = load_file(safetensor_file)
+            state_dict.update(partial_state_dict)
+
+        n_layer = kwargs.get("num_hidden_layers", None)
+        if n_layer is not None:
+            keys_to_delete = []
+            for key in state_dict.keys():
+                parts = key.split(".")
+                if len(parts) > 2 and parts[2].isdigit():
+                    layer_num = int(parts[2])
+                    if layer_num >= n_layer:
+                        keys_to_delete.append(key)
+
+            for key in keys_to_delete:
+                del state_dict[key]
+
+        model.load_state_dict(state_dict)
+        return model
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        *args,
+        **kwargs,
+    ) -> "PreTrainedModel":
+        rbln_config_kwargs = kwargs.get("rbln_config_kwargs", {})
+        rbln_quantization = rbln_config_kwargs.get("rbln_quantization", None)
+
+        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+            model = cls.get_quantized_model(*args, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
+
+        return model
 
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        wrapped_model = cls.wrapping_torch_model(model, rbln_config.meta["rbln_max_seq_len"])
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
-
-        batch_index = 3
-        dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
-
-        prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
+        def get_scripted_model():
+            # This function is nested to dealloc the example inputs before compilation.
+            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+
+            batch_index = 3
+            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+
+            prefill_scripted_model = torch.jit.trace(
+                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
+            )
+            dec_scripted_model = torch.jit.trace(
+                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
            )
+            return prefill_scripted_model, dec_scripted_model
+
+        prefill_scripted_model, dec_scripted_model = get_scripted_model()
 
         prefill_ir = rebel.torchscript_to_ir(
             prefill_scripted_model,
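
Note the nested get_scripted_model helper: as its own comment says, it exists so the dummy example tensors are deallocated when the helper returns, before the memory-hungry compilation of the traced graphs begins. The general pattern as a standalone sketch (names are illustrative):

    import torch

    def build_traced(model, make_example_inputs):
        def _trace():
            example_inputs = make_example_inputs()  # large dummy tensors live only in this scope
            return torch.jit.trace(model, example_inputs, check_trace=False)

        traced = _trace()  # example inputs are released once _trace returns
        # heavy compilation of `traced` can now run without the dummy tensors resident
        return traced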
@@ -133,28 +215,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_quantization: Optional[Dict[str, str]] = None,
         **kwargs,
     ) -> RBLNConfig:
         meta = {}
 
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` should be specified.")
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_batch_size"] = rbln_batch_size
         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
 
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+
+        if rbln_quantization is not None:
+            q_format = rbln_quantization.get("format", None)
+            q_precision = rbln_quantization.get("precision", None)
+
+            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
+                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
+                )
+            meta["rbln_quantization"] = rbln_quantization
+
         def get_input_info(
             batch_size,
             query_length,
         ):
-            head_dim = (
-                model_config.head_dim
-                if hasattr(model_config, "head_dim")
-                else model_config.hidden_size // model_config.num_attention_heads
-            )
             input_info = [
                 ("input_ids", [batch_size, query_length], "int64"),
                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
@@ -172,13 +270,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                         f"past_key_values_{i}",
                         [
                             rbln_batch_size,
-                            model_config.num_key_value_heads,
+                            num_key_value_heads,
                             rbln_max_seq_len,
                             head_dim,
                         ],
                         "float32",
                     )
-                    for i in range(model_config.num_hidden_layers * 2)
+                    for i in range(num_hidden_layers * 2)
                 ]
             )

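Together with the new utils/rbln_quantization.py (file 25 above), this lets decoder-only models load pre-quantized checkpoints; per SUPPORTED_QUANTIZATIONS, only format "rbln" with precision "w4a16" passes validation, and anything else raises the ValueError above. A hedged sketch (the local path is illustrative; get_quantized_model globs <model_id>/*.safetensors, so a local directory of quantized weights is assumed):

    from optimum.rbln import RBLNLlamaForCausalLM

    model = RBLNLlamaForCausalLM.from_pretrained(
        "./llama2-7b-w4a16",  # local dir containing pre-quantized *.safetensors
        export=True,
        rbln_max_seq_len=4096,
        rbln_quantization={"format": "rbln", "precision": "w4a16"},
    )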
@@ -295,6 +393,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             raise RuntimeError(
                 f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
             )
+
+        out_buffers = [
+            torch.empty(
+                size=[
+                    1,
+                    self.prefill_chunk_size,
+                    self.config.vocab_size,
+                ],
+                dtype=torch.float32,
+                device="cpu",
+            ),
+            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+        ]
+
         query_length = input_ids.shape[1]
         attention_mask = self.prefill_attention_mask.clone()
         for step in range(0, query_length, self.prefill_chunk_size):
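
The prefill loop now hands the runtime a pair of preallocated CPU output buffers (chunk logits plus a scalar int16 output) instead of allocating fresh tensors on every chunk. The buffer-reuse idea in plain PyTorch, as an illustrative analogy only (torch's out= convention standing in for the runtime call):

    import torch

    hidden = torch.randn(1, 128, 512)
    lm_head = torch.randn(512, 32000)
    logits_buf = torch.empty(1, 128, 32000)            # allocated once, outside the loop

    for _ in range(4):                                 # e.g. one iteration per prefill chunk
        torch.matmul(hidden, lm_head, out=logits_buf)  # writes into the same buffer each step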
@@ -314,7 +426,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
             sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-            attention_mask[:, :, :, :step] = 1
+
+            if step >= self.prefill_chunk_size:
+                attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
             logits, _ = self.prefill_decoder(
@@ -322,6 +436,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 attention_mask.contiguous(),
                 sliced_cache_positions.contiguous(),
                 torch.tensor(batch_idx, dtype=torch.int16),
+                out=out_buffers,
             )
             logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

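The mask update is also incremental now: because attention_mask persists across iterations, each step only needs to fully unmask the chunk completed in the previous iteration instead of rewriting every column below step; the cumulative result is unchanged. A toy trace of the pattern (chunk size 2, sequence length 6; purely illustrative, not the library's API):

    import torch

    chunk = 2
    causal = torch.tril(torch.ones(chunk, chunk))
    mask = torch.zeros(1, 1, chunk, 6)
    for step in range(0, 6, chunk):
        if step >= chunk:
            mask[:, :, :, step - chunk : step] = 1   # finalize the previous chunk
        mask[:, :, :, step : step + chunk] = causal  # causal pattern within the current chunk
        print(f"step {step}:\n{mask[0, 0]}")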
optimum/rbln/transformers/models/gemma/gemma_architecture.py CHANGED
@@ -39,9 +39,16 @@ from ...models.decoderonly import (
 class GemmaWrapper(DecoderOnlyWrapper):
     def get_forward_dict(self):
         forward_dict = {}
-        forward_dict.update({"wrapper": GemmaModel.forward, "model": DecoderOnlyDecoderLayer.forward, "decoder_layer": DecoderOnlyAttention.forward,})
+        forward_dict.update(
+            {
+                "wrapper": GemmaModel.forward,
+                "model": DecoderOnlyDecoderLayer.forward,
+                "decoder_layer": DecoderOnlyAttention.forward,
+            }
+        )
         return forward_dict
 
+
 class GemmaModel:
     def forward(
         self,
@@ -54,7 +61,7 @@ class GemmaModel:
         use_cache: Optional[bool] = True,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
-        forward_dict : Optional[Dict[str, classmethod]] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
         rotary_pos_emb=None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         # embed positions
@@ -89,7 +96,7 @@ class GemmaModel:
             batch_ids=batch_ids,
             cos=cos,
             sin=sin,
-            forward_dict=forward_dict
+            forward_dict=forward_dict,
         )
 
         hidden_states = layer_outputs[0]