optimum-rbln 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
- optimum/rbln/__init__.py +3 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +4 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
- optimum/rbln/modeling_alias.py +5 -1
- optimum/rbln/modeling_base.py +53 -19
- optimum/rbln/transformers/__init__.py +3 -1
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +4 -3
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +137 -22
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +8 -2
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +10 -3
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +31 -26
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -32,6 +32,7 @@ _import_structure = {
     "modeling_alias": [
         "RBLNASTForAudioClassification",
         "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
         "RBLNT5ForConditionalGeneration",
         "RBLNBartForConditionalGeneration",
@@ -61,6 +62,7 @@ _import_structure = {
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
         "RBLNXLMRobertaModel",
     ],
@@ -126,6 +128,7 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
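Together with the transformers sub-package changes further down, these exports make the 0.1.9 additions importable from the package root. A minimal, hedged usage sketch; the checkpoint IDs are placeholders and the export=True / rbln_* keyword pattern follows the usual optimum conventions rather than anything shown in this diff:

    # Illustrative only: checkpoint IDs are placeholders, keyword names follow the
    # rbln_* pattern visible elsewhere in this diff.
    from optimum.rbln import RBLNDistilBertForQuestionAnswering, RBLNMistralForCausalLM

    # Compile a Mistral-family causal LM for the RBLN NPU.
    mistral = RBLNMistralForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",  # placeholder checkpoint
        export=True,                           # assumed optimum-style "compile on load" flag
        rbln_max_seq_len=4096,
        rbln_batch_size=1,
    )

    # Compile a DistilBERT extractive-QA model through the new alias.
    distilbert_qa = RBLNDistilBertForQuestionAnswering.from_pretrained(
        "distilbert-base-cased-distilled-squad",  # placeholder checkpoint
        export=True,
        rbln_max_seq_len=384,
    )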
optimum/rbln/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = '0.1.8'
+__version__ = '0.1.9'
optimum/rbln/diffusers/models/autoencoder_kl.py
CHANGED
@@ -26,7 +26,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import rebel
-import torch
+import torch  # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -38,12 +38,12 @@ from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     import torch
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
optimum/rbln/diffusers/models/controlnet.py
CHANGED
@@ -34,12 +34,13 @@ from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
+logger = logging.getLogger(__name__)
+
+
 class _ControlNetModel(torch.nn.Module):
     def __init__(self, controlnet: "ControlNetModel"):
         super().__init__()
@@ -138,7 +139,7 @@ class RBLNControlNetModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
optimum/rbln/diffusers/models/unet_2d_condition.py
CHANGED
@@ -35,11 +35,11 @@ from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 class _UNet_SD(torch.nn.Module):
     def __init__(self, unet: "UNet2DConditionModel"):
@@ -172,7 +172,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         else:
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py
CHANGED
@@ -37,11 +37,11 @@ from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNMultiControlNetModel(RBLNModel):
     def __init__(
@@ -79,7 +79,6 @@ class RBLNMultiControlNetModel(RBLNModel):
         model_id: Union[str, Path],
         **kwargs,
     ) -> RBLNModel:
-
         idx = 0
         controlnets = []
         model_path_to_load = model_id
optimum/rbln/modeling_alias.py
CHANGED
@@ -36,7 +36,11 @@ class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
 
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
-    pass
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+
+class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
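The new alias exists because DistilBERT tokenizers emit no token_type_ids; the class attribute declared here is what RBLNModelForQuestionAnswering._get_rbln_config (see the modeling_base.py hunks below) falls back to when no rbln_model_input_names kwarg is supplied. A hedged inference sketch with a placeholder checkpoint:

    # Sketch: running the compiled DistilBERT QA model; checkpoint and padding settings are illustrative.
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNDistilBertForQuestionAnswering

    model_id = "distilbert-base-cased-distilled-squad"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = RBLNDistilBertForQuestionAnswering.from_pretrained(model_id, export=True, rbln_max_seq_len=384)

    # Inputs must be padded to the compiled sequence length; note: no token_type_ids.
    inputs = tokenizer(
        "What does the alias drop?",
        "The DistilBERT alias drops token_type_ids from the compiled model inputs.",
        return_tensors="pt",
        padding="max_length",
        max_length=384,
    )
    outputs = model(**inputs)  # start_logits / end_logits, as with the transformers model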
optimum/rbln/modeling_base.py
CHANGED
@@ -51,10 +51,15 @@ from .utils.runtime_utils import UnavailableRuntime
 from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PreTrainedModel,
+    )
+
+logger = logging.getLogger(__name__)
 
 
 class RBLNBaseModel(OptimizedModel, ABC):
@@ -156,13 +161,23 @@ class RBLNBaseModel(OptimizedModel, ABC):
             Directory where to save the model file.
         """
         real_save_dir = self.model_save_dir / self.subfolder
+        save_directory_path = Path(save_directory)
         if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            if save_directory_path.absolute() == real_save_dir.absolute():
+                raise FileExistsError(
+                    f"Cannot save model to '{save_directory}'. "
+                    f"This directory already exists and contains the model files."
+                )
            shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
            self.config.save_pretrained(save_directory)
            if self.generation_config is not None:
                self.generation_config.save_pretrained(save_directory)
        else:
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
+                f"Cannot save to the specified destination '{save_directory}'. "
+                f"Please ensure the model directory exists and you have the necessary permissions to access it."
+            )
 
     @classmethod
     def _from_pretrained(
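The reworked _save_pretrained adds two explicit failure modes: saving onto the directory that already holds the compiled artifacts now raises FileExistsError instead of copying a tree onto itself, and a missing source directory produces a more descriptive FileNotFoundError. A behavior sketch; the class and the directory names are placeholders:

    # Sketch of the new guard behavior; paths are illustrative.
    from optimum.rbln import RBLNResNetForImageClassification

    model = RBLNResNetForImageClassification.from_pretrained("./resnet50-rbln")  # precompiled directory

    model.save_pretrained("./resnet50-rbln-copy")  # copies compiled artifacts, config, generation config
    model.save_pretrained("./resnet50-rbln")       # raises FileExistsError if this resolves to the
                                                   # directory the model is already stored in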
@@ -196,7 +211,12 @@ class RBLNBaseModel(OptimizedModel, ABC):
                 token = HfFolder().get_token()
             else:
                 token = use_auth_token
-            repo_files = list(map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)))
+            repo_files = list(
+                map(
+                    Path,
+                    HfApi().list_repo_files(model_id, revision=revision, token=token),
+                )
+            )
 
             pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
             rbln_files = [p for p in repo_files if p.match(pattern)]
@@ -287,7 +307,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
             preprocessors,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
+            rbln_compiled_models=(None if rbln_optimize_host_memory else rbln_compiled_models),
             **kwargs,
         )
 
@@ -377,7 +397,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
         return self.forward(*args, **kwargs)
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model
 
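wrap_model_if_needed now receives the RBLNConfig alongside the model, and every override touched in this release (controlnet, unet_2d_condition, clip, and the decoder-only base) was updated to match. A sketch of a custom subclass under the new signature; the wrapper class and model class are hypothetical, import paths mirror this diff:

    # Hypothetical subclass; _MyTraceWrapper is an illustrative tracing wrapper,
    # analogous to _TextEncoder / _UNet_SD elsewhere in this diff.
    import torch

    from optimum.rbln.modeling_base import RBLNModel
    from optimum.rbln.modeling_config import RBLNConfig


    class _MyTraceWrapper(torch.nn.Module):
        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.model = model

        def forward(self, *args, **kwargs):
            return self.model(*args, **kwargs)


    class RBLNMyEncoderModel(RBLNModel):
        @classmethod
        def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
            # The compile-time config is now available here, so wrapping decisions
            # can depend on it instead of being hard-coded.
            return _MyTraceWrapper(model).eval()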
@@ -400,7 +420,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
     @classmethod
     @abstractmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         # compiled_models -> runtimes
         pass
@@ -497,7 +519,7 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        model = cls.wrap_model_if_needed(model)
+        model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_runtime_configs = list(rbln_config.values())
         if len(rbln_runtime_configs) != 1:
             raise ValueError
@@ -598,7 +620,9 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
@@ -618,8 +642,8 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
         rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
     ) -> RBLNConfig:
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -629,15 +653,15 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")
 
-        if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
+        if rbln_model_input_names is not None:
+            cls.rbln_model_input_names = rbln_model_input_names
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
+            for model_input_name in cls.rbln_model_input_names
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
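With this change the kwarg, when given, overrides the class-level rbln_model_input_names set by the aliases in modeling_alias.py; when omitted, the class attribute is used instead of the old hard-coded BERT triple. A hedged export sketch with a placeholder checkpoint:

    # Sketch: overriding the input names for a checkpoint whose tokenizer emits no token_type_ids.
    from optimum.rbln import RBLNBertForQuestionAnswering

    model = RBLNBertForQuestionAnswering.from_pretrained(
        "some-org/some-qa-checkpoint",  # placeholder
        export=True,
        rbln_max_seq_len=384,
        rbln_model_input_names=["input_ids", "attention_mask"],
    )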
@@ -674,7 +698,13 @@ class RBLNModelForImageClassification(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size, rbln_image_size], "float32")]
+        input_info = [
+            (
+                "pixel_values",
+                [rbln_batch_size, 3, rbln_image_size, rbln_image_size],
+                "float32",
+            )
+        ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
         rbln_runtime_config.batch_size = rbln_batch_size
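Each input_info entry is a (name, shape, dtype) triple consumed by RBLNRuntimeConfig; the reflow above only changes layout. For reference, with a batch size of 1 and a 224x224 input (values illustrative) the entry expands to:

    # Equivalent single-entry form; 1 and 224 stand in for rbln_batch_size and rbln_image_size.
    input_info = [("pixel_values", [1, 3, 224, 224], "float32")]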
@@ -739,7 +769,11 @@ class RBLNModelForAudioClassification(RBLNModel):
         meta["rbln_num_mel_bins"] = rbln_num_mel_bins
 
         model_input_info = [
-            ("input_values", [rbln_batch_size, rbln_max_length, rbln_num_mel_bins], "float32"),
+            (
+                "input_values",
+                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
+                "float32",
+            ),
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
@@ -777,7 +811,6 @@ class RBLNModelForSequenceClassification(RBLNModel):
         rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
     ) -> RBLNConfig:
-
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -812,6 +845,7 @@ class RBLNModelForSequenceClassification(RBLNModel):
 
         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
 
+
 class RBLNModelForMaskedLM(RBLNModel):
     model_type = "rbln_model"
     auto_model_class = AutoModelForMaskedLM
optimum/rbln/transformers/__init__.py
CHANGED
@@ -39,7 +39,8 @@ _import_structure = {
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
-        "RBLNXLMRobertaModel",
+        "RBLNMistralForCausalLM",
+        "RBLNXLMRobertaModel",
     ],
 }
 
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -27,6 +27,7 @@ from .gemma import RBLNGemmaForCausalLM
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
 from .midm import RBLNMidmLMHeadModel
+from .mistral import RBLNMistralForCausalLM
 from .wav2vec2 import RBLNWav2Vec2ForCTC
 from .whisper import RBLNWhisperForConditionalGeneration
 from .xlm_roberta import RBLNXLMRobertaModel
optimum/rbln/transformers/models/clip/modeling_clip.py
CHANGED
@@ -70,7 +70,7 @@ class RBLNCLIPTextModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
@@ -49,18 +49,19 @@ class DecoderOnlyWrapper(torch.nn.Module):
             self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
         )
         self.max_seq_len = max_seq_len
+        self.rope_scaling = getattr(self.config, "rope_scaling", None)
         self.rotary_emb = self._init_rope()
 
     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        if self.rope_scaling is None:
             rotary_emb = RotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.config.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
+            scaling_type = self.rope_scaling["type"]
+            scaling_factor = self.rope_scaling["factor"]
             if scaling_type == "linear":
                 rotary_emb = LinearScalingRotaryEmbedding(
                     self.head_dim,
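DecoderOnlyWrapper now reads rope_scaling off the config once (configs without the attribute yield None) and branches on the cached value. The dict follows the transformers convention. A small sketch mirroring the branch, with illustrative values:

    # Sketch mirroring _init_rope's branch; "config" stands for a transformers PretrainedConfig.
    def describe_rope(config) -> str:
        rope_scaling = getattr(config, "rope_scaling", None)  # e.g. None, or {"type": "linear", "factor": 2.0}
        if rope_scaling is None:
            return "plain RotaryEmbedding"
        if rope_scaling["type"] == "linear":
            return f"LinearScalingRotaryEmbedding(scaling_factor={rope_scaling['factor']})"
        return f"other scaling type handled further down in the hunk: {rope_scaling['type']}"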
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
@@ -20,18 +20,22 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import glob
 import logging
-from abc import ABC
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import rebel  # noqa: F401
 import torch  # noqa: F401
-from
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import no_init_weights
 
 from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...utils.rbln_quantization import replace_quantized_linear_layers
 
 
 logger = logging.getLogger(__name__)
@@ -44,6 +48,12 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )
 
+SUPPORTED_QUANTIZATIONS = {
+    "rbln": [
+        "w4a16",
+    ],
+}
+
 
 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
@@ -78,26 +88,98 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         self.decoder = RBLNRuntimeModel(runtime=self.model[1], main_input_name="input_ids")
 
     @classmethod
-
-
-
+    def get_quantized_model(
+        cls,
+        model_id: str,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        kwargs = cls.update_kwargs(kwargs)
+
+        config = AutoConfig.from_pretrained(
+            model_id,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config)
+        replace_quantized_linear_layers(model)
+
+        state_dict = {}
+        for safetensor_file in glob.glob(f"{model_id}/*.safetensors"):
+            partial_state_dict = load_file(safetensor_file)
+            state_dict.update(partial_state_dict)
+
+        n_layer = kwargs.get("num_hidden_layers", None)
+        if n_layer is not None:
+            keys_to_delete = []
+            for key in state_dict.keys():
+                parts = key.split(".")
+                if len(parts) > 2 and parts[2].isdigit():
+                    layer_num = int(parts[2])
+                    if layer_num >= n_layer:
+                        keys_to_delete.append(key)
+
+            for key in keys_to_delete:
+                del state_dict[key]
+
+        model.load_state_dict(state_dict)
+        return model
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        *args,
+        **kwargs,
+    ) -> "PreTrainedModel":
+        rbln_config_kwargs = kwargs.get("rbln_config_kwargs", {})
+        rbln_quantization = rbln_config_kwargs.get("rbln_quantization", None)
+
+        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+            model = cls.get_quantized_model(*args, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
+
+        return model
 
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        wrapped_model = cls.wrap_model_if_needed(model)
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-
-
+        def get_scripted_model():
+            # This function is nested to dealloc the example inputs before compilation.
+            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=4)
+
+            batch_index = 3
+            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
 
-
-
+            prefill_scripted_model = torch.jit.trace(
+                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
+            )
+            dec_scripted_model = torch.jit.trace(
+                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
+            )
+            return prefill_scripted_model, dec_scripted_model
 
-        prefill_scripted_model = torch.jit.trace(wrapped_model, prefill_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
+        prefill_scripted_model, dec_scripted_model = get_scripted_model()
 
         prefill_ir = rebel.torchscript_to_ir(
             prefill_scripted_model,
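get_pytorch_model now routes to get_quantized_model whenever rbln_quantization requests the "rbln" format: the model skeleton is built under no_init_weights, its Linear layers are swapped by replace_quantized_linear_layers, and pre-quantized *.safetensors shards are loaded from the local model_id directory. A hedged loading sketch; the path is a placeholder and must point at a local directory laid out that way:

    # Sketch: loading and compiling a pre-quantized w4a16 checkpoint; the path is a placeholder.
    from optimum.rbln import RBLNLlamaForCausalLM

    model = RBLNLlamaForCausalLM.from_pretrained(
        "/path/to/llama-7b-w4a16-rbln",  # local dir with config + *.safetensors in the "rbln" layout
        export=True,
        rbln_quantization={"format": "rbln", "precision": "w4a16"},
        rbln_max_seq_len=4096,
        rbln_batch_size=1,
    )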
@@ -133,28 +215,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         model_config: "PretrainedConfig",
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_quantization: Optional[Dict[str, str]] = None,
         **kwargs,
     ) -> RBLNConfig:
         meta = {}
 
         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` should be specified.")
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
         meta["rbln_max_seq_len"] = rbln_max_seq_len
         meta["rbln_batch_size"] = rbln_batch_size
         meta["rbln_prefill_chunk_size"] = prefill_chunk_size
 
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+
+        if rbln_quantization is not None:
+            q_format = rbln_quantization.get("format", None)
+            q_precision = rbln_quantization.get("precision", None)
+
+            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
+                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
+                )
+            meta["rbln_quantization"] = rbln_quantization
+
         def get_input_info(
             batch_size,
             query_length,
         ):
-            head_dim = (
-                model_config.head_dim
-                if hasattr(model_config, "head_dim")
-                else model_config.hidden_size // model_config.num_attention_heads
-            )
             input_info = [
                 ("input_ids", [batch_size, query_length], "int64"),
                 ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
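The config introspection above now tolerates both GPT-2-style (n_positions / n_head / n_layer) and Llama-style (max_position_embeddings / num_attention_heads / num_hidden_layers, plus num_key_value_heads for GQA) attribute names, with head_dim falling back to hidden_size // num_attention_heads. A compact restatement of the fallback chain:

    # Restates the getattr fallbacks above; model_config is any transformers PretrainedConfig-like object.
    def resolve_shapes(model_config):
        max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(model_config, "n_positions", None)
        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
        # Each of the 2 * num_hidden_layers KV-cache inputs declared in the next hunk is shaped
        # [rbln_batch_size, num_key_value_heads, max_seq_len, head_dim].
        return max_seq_len, num_attention_heads, num_key_value_heads, num_hidden_layers, head_dim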
@@ -172,13 +270,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                     f"past_key_values_{i}",
                     [
                         rbln_batch_size,
-
+                        num_key_value_heads,
                         rbln_max_seq_len,
                         head_dim,
                     ],
                     "float32",
                 )
-                for i in range(
+                for i in range(num_hidden_layers * 2)
             ]
         )
 
@@ -295,6 +393,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             raise RuntimeError(
                 f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
             )
+
+        out_buffers = [
+            torch.empty(
+                size=[
+                    1,
+                    self.prefill_chunk_size,
+                    self.config.vocab_size,
+                ],
+                dtype=torch.float32,
+                device="cpu",
+            ),
+            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+        ]
+
         query_length = input_ids.shape[1]
         attention_mask = self.prefill_attention_mask.clone()
         for step in range(0, query_length, self.prefill_chunk_size):
@@ -314,7 +426,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
 
             sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-
+
+            if step >= self.prefill_chunk_size:
+                attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
             logits, _ = self.prefill_decoder(
@@ -322,6 +436,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 attention_mask.contiguous(),
                 sliced_cache_positions.contiguous(),
                 torch.tensor(batch_idx, dtype=torch.int16),
+                out=out_buffers,
             )
             logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)
 
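Two prefill-path changes land here: the decoder writes its outputs into preallocated CPU buffers (chunk-sized logits plus an int16 scalar) via out=, and after the first chunk the loop only unmasks the chunk it just finished instead of rewriting the whole prefix. A self-contained sketch of the resulting mask schedule; shapes and sizes are illustrative stand-ins for the runtime attributes:

    # Sketch of the chunked-prefill attention-mask schedule.
    import torch

    chunk = 128                                               # mirrors prefill_chunk_size
    max_seq_len = 512                                         # illustrative
    query_length = 300                                        # illustrative prompt length
    causal_mask = torch.tril(torch.ones(1, 1, chunk, chunk))  # stands for self.causal_mask
    attention_mask = torch.zeros(1, 1, chunk, max_seq_len)    # stands for the cloned prefill mask

    for step in range(0, query_length, chunk):
        if step >= chunk:
            # Tokens from the previous chunk become fully visible once they sit in the KV cache.
            attention_mask[:, :, :, step - chunk : step] = 1
        # The current chunk attends to itself through the causal block.
        attention_mask[:, :, :, step : step + chunk] = causal_mask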
optimum/rbln/transformers/models/gemma/gemma_architecture.py
CHANGED
@@ -39,9 +39,16 @@ from ...models.decoderonly import (
 class GemmaWrapper(DecoderOnlyWrapper):
     def get_forward_dict(self):
         forward_dict = {}
-        forward_dict.update(
+        forward_dict.update(
+            {
+                "wrapper": GemmaModel.forward,
+                "model": DecoderOnlyDecoderLayer.forward,
+                "decoder_layer": DecoderOnlyAttention.forward,
+            }
+        )
         return forward_dict
 
+
 class GemmaModel:
     def forward(
         self,
@@ -54,7 +61,7 @@ class GemmaModel:
         use_cache: Optional[bool] = True,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
-        forward_dict
+        forward_dict: Optional[Dict[str, classmethod]] = None,
         rotary_pos_emb=None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         # embed positions
@@ -89,7 +96,7 @@ class GemmaModel:
                 batch_ids=batch_ids,
                 cos=cos,
                 sin=sin,
-                forward_dict=forward_dict
+                forward_dict=forward_dict,
             )
 
             hidden_states = layer_outputs[0]
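GemmaWrapper only overrides the dispatch table: the "wrapper" slot points at the Gemma-specific GemmaModel.forward while the layer and attention slots reuse the shared decoder-only implementations. Other decoder-only ports (presumably including the new Mistral support) would plug in the same way; a hypothetical sketch, where MyModel.forward stands for a model-specific top-level forward:

    # Hypothetical: wiring a new decoder-only architecture into the shared wrapper.
    # How each slot is invoked lives in the decoderonly base classes, outside this hunk.
    class MyModelWrapper(DecoderOnlyWrapper):
        def get_forward_dict(self):
            return {
                "wrapper": MyModel.forward,                     # model-specific top-level forward
                "model": DecoderOnlyDecoderLayer.forward,       # shared decoder layer forward
                "decoder_layer": DecoderOnlyAttention.forward,  # shared attention forward
            }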