optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,106 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import inspect
|
25
|
+
import logging
|
26
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
27
|
+
|
28
|
+
from transformers import AutoModel, BartConfig, BartModel, PretrainedConfig
|
29
|
+
|
30
|
+
from ....modeling_base import RBLNModel
|
31
|
+
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
32
|
+
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
if TYPE_CHECKING:
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNBartModel(RBLNModel):
    """RBLN wrapper that compiles Hugging Face ``BartModel`` for feature extraction.

    This class only customises how the compile-time ``RBLNConfig`` is built;
    everything else is inherited from ``RBLNModel``.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = BartModel
    original_config_class = BartConfig

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the compile configuration for ``BartModel``.

        Args:
            preprocessors: Tokenizer(s)/processor(s) consulted, in order, for
                ``model_max_length`` and ``model_input_names`` when those are not
                given explicitly. ``None`` or a single object is also accepted.
            model_config: Model config. Its ``max_position_embeddings`` (when
                present) is the default sequence length and also its upper bound.
            rbln_kwargs: User options; ``max_seq_len``, ``batch_size`` (default 1)
                and ``model_input_names`` are read. ``None`` means no options.

        Returns:
            An ``RBLNConfig`` holding one compile config with an int64 input of
            shape ``[batch_size, max_seq_len]`` per model input name, with
            ``max_seq_len`` recorded in ``model_cfg``.

        Raises:
            ValueError: If no sequence length can be determined, if it exceeds
                ``max_position_embeddings``, or if no input names can be determined.
        """
        if rbln_kwargs is None:
            # Fresh dict per call instead of a shared mutable default.
            rbln_kwargs = {}
        if preprocessors is None:
            preprocessors = []
        elif hasattr(preprocessors, "model_max_length") or hasattr(preprocessors, "model_input_names"):
            # A single tokenizer/processor was passed rather than a collection.
            preprocessors = [preprocessors]

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)

        # Sequence length: explicit option > model config > first preprocessor exposing it.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        # Input names: explicit option > first preprocessor exposing them > class default.
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    # BartModel's forward() does not take token_type_ids as input
                    # (some tokenizers include it). Build a filtered copy rather than
                    # calling .remove(), which would mutate the tokenizer's own list.
                    rbln_model_input_names = [
                        name for name in tokenizer.model_input_names if name != "token_type_ids"
                    ]
                    break
        if rbln_model_input_names is None:
            if hasattr(cls, "rbln_model_input_names"):
                rbln_model_input_names = cls.rbln_model_input_names
            else:
                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
                raise ValueError(
                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
                )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_bert import RBLNBertModel
|
@@ -0,0 +1,102 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import inspect
|
25
|
+
import logging
|
26
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
27
|
+
|
28
|
+
from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig
|
29
|
+
|
30
|
+
from ....modeling_base import RBLNModel
|
31
|
+
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
32
|
+
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
if TYPE_CHECKING:
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNBertModel(RBLNModel):
    """RBLN wrapper that compiles Hugging Face ``BertModel`` for feature extraction.

    This class only customises how the compile-time ``RBLNConfig`` is built;
    everything else is inherited from ``RBLNModel``.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = BertModel
    original_config_class = BertConfig

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the compile configuration for ``BertModel``.

        Args:
            preprocessors: Tokenizer(s)/processor(s) consulted, in order, for
                ``model_max_length`` and ``model_input_names`` when those are not
                given explicitly. ``None`` or a single object is also accepted.
            model_config: Model config. Its ``max_position_embeddings`` (when
                present) is the default sequence length and also its upper bound.
            rbln_kwargs: User options; ``max_seq_len``, ``batch_size`` (default 1)
                and ``model_input_names`` are read. ``None`` means no options.

        Returns:
            An ``RBLNConfig`` holding one compile config with an int64 input of
            shape ``[batch_size, max_seq_len]`` per model input name, with
            ``max_seq_len`` recorded in ``model_cfg``.

        Raises:
            ValueError: If no sequence length can be determined, if it exceeds
                ``max_position_embeddings``, or if no input names can be determined.
        """
        if rbln_kwargs is None:
            # Fresh dict per call instead of a shared mutable default.
            rbln_kwargs = {}
        if preprocessors is None:
            preprocessors = []
        elif hasattr(preprocessors, "model_max_length") or hasattr(preprocessors, "model_input_names"):
            # A single tokenizer/processor was passed rather than a collection.
            preprocessors = [preprocessors]

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)

        # Sequence length: explicit option > model config > first preprocessor exposing it.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        # Input names: explicit option > first preprocessor exposing them > class default.
        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_model_input_names = tokenizer.model_input_names
                    break
        if rbln_model_input_names is None:
            if hasattr(cls, "rbln_model_input_names"):
                rbln_model_input_names = cls.rbln_model_input_names
            else:
                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
                raise ValueError(
                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
                )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
|
@@ -21,4 +21,4 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
24
|
+
from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
|
@@ -22,14 +22,23 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import logging
|
25
|
-
from typing import TYPE_CHECKING, Optional, Union
|
25
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
26
26
|
|
27
27
|
import torch
|
28
|
-
from transformers import
|
28
|
+
from transformers import (
|
29
|
+
AutoConfig,
|
30
|
+
AutoModel,
|
31
|
+
CLIPTextConfig,
|
32
|
+
CLIPTextModel,
|
33
|
+
CLIPTextModelWithProjection,
|
34
|
+
CLIPVisionConfig,
|
35
|
+
CLIPVisionModel,
|
36
|
+
)
|
37
|
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
29
38
|
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
30
39
|
|
31
40
|
from ....modeling_base import RBLNModel
|
32
|
-
from ....modeling_config import
|
41
|
+
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
33
42
|
|
34
43
|
|
35
44
|
logger = logging.getLogger(__name__)
|
@@ -41,12 +50,10 @@ if TYPE_CHECKING:
|
|
41
50
|
class _TextEncoder(torch.nn.Module):
|
42
51
|
def __init__(self, enc: "CLIPTextModel"):
|
43
52
|
super().__init__()
|
44
|
-
enc.config.return_dict = False
|
45
|
-
enc.config.output_hidden_states = True
|
46
53
|
self.enc = enc
|
47
54
|
|
48
55
|
def forward(self, inp):
|
49
|
-
enc_out = self.enc(inp)
|
56
|
+
enc_out = self.enc(inp, output_hidden_states=True, return_dict=False)
|
50
57
|
return enc_out
|
51
58
|
|
52
59
|
|
@@ -55,9 +62,6 @@ class RBLNCLIPTextModel(RBLNModel):
|
|
55
62
|
original_model_class = CLIPTextModel
|
56
63
|
original_config_class = CLIPTextConfig
|
57
64
|
|
58
|
-
def __post_init__(self, **kwargs):
|
59
|
-
self.dtype = torch.float32
|
60
|
-
|
61
65
|
@classmethod
|
62
66
|
def from_pretrained(cls, *args, **kwargs):
|
63
67
|
configtmp = AutoConfig.from_pretrained
|
@@ -78,28 +82,32 @@ class RBLNCLIPTextModel(RBLNModel):
|
|
78
82
|
cls,
|
79
83
|
preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
|
80
84
|
model_config: "CLIPTextConfig",
|
85
|
+
rbln_kwargs: Dict[str, Any] = {},
|
81
86
|
rbln_batch_size: Optional[int] = None,
|
82
|
-
rbln_img_width: Optional[int] = None,
|
83
|
-
rbln_img_height: Optional[int] = None,
|
84
87
|
) -> RBLNConfig:
|
85
|
-
|
88
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
86
89
|
if rbln_batch_size is None:
|
87
90
|
rbln_batch_size = 1
|
88
91
|
|
89
|
-
|
90
|
-
input_info=[
|
91
|
-
(
|
92
|
-
"input_ids",
|
93
|
-
[
|
94
|
-
rbln_batch_size,
|
95
|
-
model_config.max_position_embeddings,
|
96
|
-
],
|
97
|
-
"int64",
|
98
|
-
),
|
99
|
-
],
|
100
|
-
)
|
92
|
+
model_config.return_dict = False
|
101
93
|
|
102
|
-
|
94
|
+
input_info = [
|
95
|
+
(
|
96
|
+
"input_ids",
|
97
|
+
[
|
98
|
+
rbln_batch_size,
|
99
|
+
model_config.max_position_embeddings,
|
100
|
+
],
|
101
|
+
"int64",
|
102
|
+
),
|
103
|
+
]
|
104
|
+
|
105
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
106
|
+
rbln_config = RBLNConfig(
|
107
|
+
rbln_cls=cls.__name__,
|
108
|
+
compile_cfgs=[rbln_compile_config],
|
109
|
+
rbln_kwargs=rbln_kwargs,
|
110
|
+
)
|
103
111
|
return rbln_config
|
104
112
|
|
105
113
|
def forward(self, input_ids: "torch.Tensor", **kwargs):
|
@@ -113,3 +121,97 @@ class RBLNCLIPTextModel(RBLNModel):
|
|
113
121
|
|
114
122
|
class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
    """``RBLNCLIPTextModel`` variant backed by ``CLIPTextModelWithProjection``.

    Only the source model class differs; everything else is inherited.
    """

    original_model_class = CLIPTextModelWithProjection
|
124
|
+
|
125
|
+
|
126
|
+
class _VisionEncoder(torch.nn.Module):
    """Tracing wrapper around a CLIP vision encoder.

    Every call passes ``output_hidden_states=True`` and ``return_dict=False`` to
    the wrapped encoder, and its result is returned unchanged.
    """

    def __init__(self, enc: CLIPVisionModel):
        super().__init__()
        self.enc = enc

    def forward(self, inp):
        return self.enc(inp, return_dict=False, output_hidden_states=True)
|
134
|
+
|
135
|
+
|
136
|
+
class RBLNCLIPVisionModel(RBLNModel):
    """RBLN wrapper that compiles Hugging Face ``CLIPVisionModel``.

    The model is traced through ``_VisionEncoder`` (see ``wrap_model_if_needed``),
    and ``forward`` repacks the compiled outputs into a
    ``BaseModelOutputWithPooling``.
    """

    original_model_class = CLIPVisionModel
    original_config_class = CLIPVisionConfig

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load through ``RBLNModel.from_pretrained`` with CLIP-vision loaders swapped in.

        While the parent implementation runs, ``AutoConfig.from_pretrained`` and
        ``AutoModel.from_pretrained`` are replaced by the ``CLIPVisionConfig`` /
        ``CLIPVisionModel`` loaders. The original class attributes are restored
        afterwards, including when loading raises.
        """
        patches = (
            (AutoConfig, cls.original_config_class.from_pretrained),
            (AutoModel, cls.original_model_class.from_pretrained),
        )
        # Remember the raw class-dict entries (the classmethod descriptor, or None
        # when the attribute is inherited) so restoring does not leave a bound
        # method behind on the Auto class.
        saved = [(target, target.__dict__.get("from_pretrained")) for target, _ in patches]
        for target, loader in patches:
            target.from_pretrained = loader
        try:
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Restore even on failure; otherwise the global loaders stay patched
            # for every later AutoConfig/AutoModel call in the process.
            for target, original in saved:
                if original is None:
                    delattr(target, "from_pretrained")
                else:
                    setattr(target, "from_pretrained", original)

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
        return _VisionEncoder(model).eval()

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "CLIPVisionConfig",
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the compile configuration for the CLIP vision encoder.

        Args:
            preprocessors: Unused here; kept for interface compatibility.
            model_config: Vision config. ``image_size`` is the default input size
                and ``num_channels`` (default 3) is the channel count.
            rbln_kwargs: User options; ``batch_size`` (default 1) and
                ``image_size`` (int for square, or (height, width)) are read.

        Returns:
            An ``RBLNConfig`` with one float32 ``pixel_values`` input of shape
            ``[batch_size, num_channels, height, width]``. ``batch_size`` and
            ``image_size`` are recorded in ``model_cfg``.

        Raises:
            ValueError: If no image size is available or it is not a pair.
        """
        if rbln_kwargs is None:
            # Fresh dict per call instead of a shared mutable default.
            rbln_kwargs = {}

        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        if rbln_batch_size is None:
            rbln_batch_size = 1

        rbln_image_size = rbln_kwargs.get("image_size", None)
        if rbln_image_size is None:
            rbln_image_size = getattr(model_config, "image_size", None)

        if isinstance(rbln_image_size, int):
            rbln_image_size = (rbln_image_size, rbln_image_size)

        if rbln_image_size is None:
            raise ValueError("`rbln_image_size` should be specified!")
        if len(rbln_image_size) != 2:
            raise ValueError(
                f"`rbln_image_size` should be an int or a (height, width) pair, got {rbln_image_size!r}"
            )

        # Fall back to 3 (RGB) when the config does not define a channel count.
        num_channels = getattr(model_config, "num_channels", None) or 3

        rbln_compile_config = RBLNCompileConfig(
            input_info=[
                (
                    "pixel_values",
                    [
                        rbln_batch_size,
                        num_channels,
                        rbln_image_size[0],
                        rbln_image_size[1],
                    ],
                    "float32",
                )
            ]
        )

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update(
            {
                "batch_size": rbln_batch_size,
                "image_size": rbln_image_size,
            }
        )

        return rbln_config

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Run the compiled vision encoder on ``pixel_values``.

        Extra keyword arguments are not forwarded; a warning is logged when any
        of them is truthy.

        Returns:
            ``BaseModelOutputWithPooling`` whose ``last_hidden_state`` and
            ``pooler_output`` are the first two runtime outputs and whose
            ``hidden_states`` are the remaining ones. This presumes the runtime
            returns a flat ``(last_hidden_state, pooler_output, *hidden_states)``
            sequence — TODO confirm against ``RBLNModel.forward``.
        """
        if len(kwargs) > 0 and any(kwargs.values()):
            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")

        output = super().forward(pixel_values)
        return BaseModelOutputWithPooling(
            last_hidden_state=output[0],
            pooler_output=output[1],
            hidden_states=output[2:],
        )
|
@@ -93,17 +93,29 @@ class DecoderOnlyWrapper(torch.nn.Module):
|
|
93
93
|
|
94
94
|
def forward(
|
95
95
|
self,
|
96
|
-
|
96
|
+
input_ids_or_inputs_embeds,
|
97
97
|
attention_mask,
|
98
98
|
cache_position,
|
99
99
|
batch_position,
|
100
|
+
query_idx,
|
100
101
|
*past_key_values,
|
101
102
|
):
|
102
|
-
if
|
103
|
+
if input_ids_or_inputs_embeds.shape[1] == 1:
|
103
104
|
rbln_batch_position = None
|
104
105
|
else:
|
105
106
|
rbln_batch_position = batch_position
|
106
107
|
|
108
|
+
if input_ids_or_inputs_embeds.ndim == 2:
|
109
|
+
# input_ids
|
110
|
+
input_ids = input_ids_or_inputs_embeds
|
111
|
+
inputs_embeds = None
|
112
|
+
elif input_ids_or_inputs_embeds.ndim == 3:
|
113
|
+
# inputs_embeds
|
114
|
+
input_ids = None
|
115
|
+
inputs_embeds = input_ids_or_inputs_embeds
|
116
|
+
else:
|
117
|
+
raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
|
118
|
+
|
107
119
|
# Formatting list of past_kv to DynamicCache class.
|
108
120
|
past_key_values = RebelDynamicCache.from_input_format(
|
109
121
|
cache_position,
|
@@ -115,6 +127,7 @@ class DecoderOnlyWrapper(torch.nn.Module):
|
|
115
127
|
outputs = forward_dict["wrapper"](
|
116
128
|
self.model,
|
117
129
|
input_ids=input_ids,
|
130
|
+
inputs_embeds=inputs_embeds,
|
118
131
|
attention_mask=attention_mask,
|
119
132
|
position_ids=cache_position,
|
120
133
|
past_key_values=past_key_values,
|
@@ -124,11 +137,14 @@ class DecoderOnlyWrapper(torch.nn.Module):
|
|
124
137
|
)
|
125
138
|
|
126
139
|
hidden_states = outputs[0]
|
140
|
+
if batch_position >= 0:
|
141
|
+
hidden_states = hidden_states[:, query_idx].unsqueeze(1)
|
142
|
+
|
127
143
|
logits = self.lm_head(hidden_states)
|
128
144
|
|
129
145
|
output = (logits,) + outputs[1:]
|
130
146
|
|
131
|
-
return output, batch_position
|
147
|
+
return output, batch_position + query_idx
|
132
148
|
|
133
149
|
|
134
150
|
class DecoderOnlyAttention:
|
@@ -323,8 +339,16 @@ class DecoderOnlyModel:
|
|
323
339
|
forward_dict: Optional[Dict[str, classmethod]] = None,
|
324
340
|
rotary_pos_emb=None,
|
325
341
|
) -> BaseModelOutputWithPast:
|
342
|
+
# retrieve input_ids and inputs_embeds
|
343
|
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
344
|
+
raise ValueError(
|
345
|
+
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
346
|
+
)
|
347
|
+
|
326
348
|
# embed positions
|
327
|
-
inputs_embeds
|
349
|
+
if inputs_embeds is None:
|
350
|
+
inputs_embeds = self.embed_tokens(input_ids)
|
351
|
+
|
328
352
|
hidden_states = inputs_embeds
|
329
353
|
attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
|
330
354
|
|