optimum-rbln 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in the public registry.
- optimum/rbln/__init__.py +3 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +4 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
- optimum/rbln/modeling_alias.py +5 -1
- optimum/rbln/modeling_base.py +53 -19
- optimum/rbln/transformers/__init__.py +3 -1
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +4 -3
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +137 -22
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +8 -2
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +10 -3
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +31 -26
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma/modeling_gemma.py

@@ -23,14 +23,19 @@
 
 import inspect
 import logging
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
-from transformers import GemmaForCausalLM
+from transformers import GemmaForCausalLM
 
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gemma_architecture import GemmaWrapper
 
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ....modeling_config import RBLNConfig
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +51,8 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
 
     @classmethod
-    def
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
         return GemmaWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
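Gemma, Llama, GPT-2, Midm, and the new Mistral model all converge on this same hook: the shared `RBLNDecoderOnlyModelForCausalLM` base class (extended in `modeling_decoderonly.py`, +137/-22 in the file list above) now derives the `RBLNConfig`, and each subclass only supplies its tracing wrapper. A minimal sketch of the resulting contract; `MyWrapper` and `RBLNMyModelForCausalLM` are hypothetical names, and the import path is an assumption based on the package layout:

```python
# Sketch only: wrapper and class names below are hypothetical.
import torch

from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM


class MyWrapper(torch.nn.Module):
    """Stand-in for GemmaWrapper/LlamaWrapper-style tracing wrappers."""

    def __init__(self, model, max_seq_len):
        super().__init__()
        self.model, self.max_seq_len = model, max_seq_len


class RBLNMyModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    # The package keeps @classmethod while naming the first parameter "self";
    # the sketch mirrors the signature exactly as the diff shows it.
    @classmethod
    def wrap_model_if_needed(self, model, rbln_config):
        # rbln_config.meta is populated by the shared base class's
        # _get_rbln_config; subclasses only read what their wrapper needs.
        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
        return MyWrapper(model, rbln_max_seq_len).eval()
```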
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,23 +23,18 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
-from transformers import GPT2LMHeadModel
+from transformers import GPT2LMHeadModel
 
-from ....modeling_config import RBLNConfig
+from ....modeling_config import RBLNConfig
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
 logger = logging.getLogger(__name__)
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import PreTrainedModel
 
 
 class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -57,7 +52,8 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
 
     @classmethod
-    def
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
         return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
@@ -74,82 +70,3 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        **kwargs,
-    ) -> RBLNConfig:
-        meta = {}
-
-        prefill_chunk_size = 128
-        if rbln_max_seq_len is None:  # differenct from llama
-            rbln_max_seq_len = getattr(model_config, "n_positions", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-
-        def get_input_info(
-            batch_size,
-            query_length,
-        ):
-            head_dim = (
-                model_config.head_dim
-                if hasattr(model_config, "head_dim")
-                else model_config.hidden_size // model_config.n_head
-            )
-            input_info = [
-                ("input_ids", [batch_size, query_length], "int64"),
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-                ("batch_position", [], "int16"),
-            ]
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_batch_size,
-                            model_config.n_head,  # differenct from llama
-                            rbln_max_seq_len,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(model_config.n_layer * 2)  # differenct from llama
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
-            batch_size=1,
-            query_length=prefill_chunk_size,
-        )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
-            query_length=1,
-        )
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
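The block removed here (and the near-identical one removed from `modeling_midm.py` below) is not lost: the prefill/decode input specs, i.e. the 128-token prefill chunk, the `input_ids`/`attention_mask`/`cache_position`/`batch_position` inputs, and the per-layer KV-cache shapes, now come from the shared `_get_rbln_config` in `modeling_decoderonly.py`. The user-facing compile call is unchanged; a usage sketch, with argument values that are illustrative rather than mandated by the diff:

```python
from optimum.rbln import RBLNGPT2LMHeadModel

# Compile a HF checkpoint for RBLN NPUs; the rbln_* kwargs now flow into the
# shared decoder-only _get_rbln_config rather than a per-model override.
model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",
    export=True,            # trace and compile instead of loading a prebuilt binary
    rbln_max_seq_len=1024,  # falls back to config.n_positions when omitted
    rbln_batch_size=1,      # decode-phase batch size
)
model.save_pretrained("gpt2-rbln")
```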
optimum/rbln/transformers/models/llama/modeling_llama.py

@@ -23,14 +23,19 @@
 
 import inspect
 import logging
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
-from transformers import LlamaForCausalLM
+from transformers import LlamaForCausalLM
 
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .llama_architecture import LlamaWrapper
 
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ....modeling_config import RBLNConfig
+
 logger = logging.getLogger(__name__)
 
 
@@ -46,7 +51,8 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
 
    @classmethod
-    def
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
         return LlamaWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -23,11 +23,9 @@
 
 import inspect
 import logging
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
-from
-
-from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ....modeling_config import RBLNConfig
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
@@ -38,10 +36,7 @@ from .midm_architecture import (
 logger = logging.getLogger(__name__)
 if TYPE_CHECKING:
     from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
+        PreTrainedModel,
     )
 
 
@@ -60,7 +55,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
 
     @classmethod
-    def
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
         return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
@@ -77,82 +73,3 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
-
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        **kwargs,
-    ) -> RBLNConfig:
-        meta = {}
-
-        prefill_chunk_size = 128
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "n_positions", None)
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
-
-        def get_input_info(
-            batch_size,
-            query_length,
-        ):
-            head_dim = (
-                model_config.head_dim
-                if hasattr(model_config, "head_dim")
-                else model_config.hidden_size // model_config.n_head
-            )
-            input_info = [
-                ("input_ids", [batch_size, query_length], "int64"),
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-                ("batch_position", [], "int16"),
-            ]
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_batch_size,
-                            model_config.n_head,
-                            rbln_max_seq_len,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(model_config.n_layer * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
-            batch_size=1,
-            query_length=prefill_chunk_size,
-        )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
-            query_length=1,
-        )
-
-        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
-        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
-
-        dec_rbln_runtime_config.batch_size = rbln_batch_size
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
-            _rbln_meta=meta,
-        )
-
-        return rbln_config
optimum/rbln/transformers/models/mistral/__init__.py (new file)

@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_mistral import RBLNMistralForCausalLM
optimum/rbln/transformers/models/mistral/mistral_architecture.py (new file)

@@ -0,0 +1,29 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+class MistralForCausalLMWrapper(DecoderOnlyWrapper):
+    pass
optimum/rbln/transformers/models/mistral/modeling_mistral.py (new file)

@@ -0,0 +1,68 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import MistralForCausalLM
+
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .mistral_architecture import MistralForCausalLMWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ....modeling_config import RBLNConfig
+
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Llama Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
+        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(MistralForCausalLM, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -70,7 +70,7 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
     auto_model_class = AutoModelForMaskedLM
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
 
     @classmethod
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -36,6 +36,7 @@ logger = logging.getLogger(__name__)
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
+
 class RBLNXLMRobertaModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
     original_model_class = XLMRobertaModel
@@ -81,7 +82,6 @@ class RBLNXLMRobertaModel(RBLNModel):
         rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
     ) -> RBLNConfig:
-
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -118,7 +118,13 @@ class RBLNXLMRobertaModel(RBLNModel):
 
         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
 
-    def forward(
+    def forward(
+        self,
+        input_ids: "torch.Tensor",
+        attention_mask: "torch.Tensor",
+        token_type_ids: "torch.Tensor" = None,
+        **kwargs,
+    ):
         if token_type_ids is None:
             token_type_ids = torch.zeros_like(input=input_ids, dtype=torch.int64)
         output = super().forward(input_ids, attention_mask, token_type_ids)
optimum/rbln/transformers/utils/__init__.py: file without changes
optimum/rbln/transformers/utils/rbln_quantization.py (new file)

@@ -0,0 +1,109 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+from typing import Any, List
+
+import torch
+from torch.nn import Linear, Parameter
+from torch.nn import functional as F
+
+
+QUANTIZED_WEIGHTS = [
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+
+
+def replace_quantized_linear_layers(
+    module: torch.nn.Module,
+) -> None:
+    """Replace target(quantized) linear layer's forward to qlinear forward
+
+    Args:
+        module (torch.nn.Module): The module containing the linear layers to be replaced.
+            For example, this could be an instance of a model like
+            LlamaForCausalLM().
+    """
+    processed_names: List[str] = []
+
+    for name, layer in module.named_modules():
+        is_replace_linear = name.split(".")[-1] in QUANTIZED_WEIGHTS
+        if isinstance(layer, torch.nn.Linear) and is_replace_linear:
+            *parent_address, child_name = name.split(".")
+            parent = access_attribute(module, parent_address)
+            setattr(parent, child_name, get_qlinear(layer))
+            processed_names.append(name)
+    names_repr = ", ".join(processed_names)
+    print(f"Replace the following linear layers as qlinear layer:\n {{{names_repr}}}")
+
+
+def access_attribute(obj: Any, tokens: List[str]) -> Any:
+    """Get attribute of given object.
+
+    Args:
+        obj: object
+
+        tokens (List[str]): attribute names to access, must be in correct order
+
+    Returns:
+        Any: accessed attribute
+
+    Raises:
+        AttributeError: If attribute doesn't exists
+    """
+    if len(tokens) == 0:
+        return obj
+    return access_attribute(getattr(obj, tokens[0]), tokens[1:])
+
+
+def get_qlinear(layer: Linear):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Perform weight-only quantized linear layer.
+
+        Forward workflow:
+        - cast weight to high precision
+        - multiply scale factor to weight
+        - call torch.nn.functional linear
+        Note:
+        - Please don't modify following workflow
+        - if the workflow must be changed please contact Rebellions
+        """
+        if inputs.dtype != self.scales.dtype:
+            raise TypeError(f"Expected tensor of dtype {self.scales.dtype} but got {inputs.dtype}")
+        w_fp = self.weight.type(inputs.dtype)
+        w_fp *= self.scales.view(-1, 1)
+        return F.linear(inputs, w_fp, self.bias)
+
+    keep = layer.weight.to(torch.int8)
+    layer.weight = None
+    del layer.weight
+    layer.weight = Parameter(keep, requires_grad=False)
+    layer.scales = Parameter(torch.ones(layer.out_features, dtype=torch.float32), requires_grad=False)
+    layer.forward = lambda *args, **kwargs: forward(layer, *args, **kwargs)
+    return layer
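The helper swaps every projection named in `QUANTIZED_WEIGHTS` for a weight-only quantized linear: int8 weights plus a per-output-channel float32 `scales` parameter, dequantized on the fly as `F.linear(x, weight.to(x.dtype) * scales.view(-1, 1), bias)`. A self-contained sketch of the mechanism; the toy module below is illustrative, not from the package:

```python
import torch

from optimum.rbln.transformers.utils.rbln_quantization import (
    replace_quantized_linear_layers,
)


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)   # name is in QUANTIZED_WEIGHTS
        self.lm_head = torch.nn.Linear(8, 8)  # not targeted, left untouched


block = ToyBlock()
replace_quantized_linear_layers(block)

x = torch.randn(2, 8)              # inputs must match the scales dtype (float32)
y = block.q_proj(x)                # dequantize-then-linear forward
print(block.q_proj.weight.dtype)   # torch.int8
print(block.lm_head.weight.dtype)  # torch.float32
```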
optimum/rbln/utils/import_utils.py

@@ -53,8 +53,7 @@ def is_rbln_available() -> bool:
 
 
 def check_version_compats() -> None:
-    warnings.filterwarnings(action="always", category=ImportWarning)
-
+    warnings.filterwarnings(action="always", category=ImportWarning, module="optimum.rbln")
     my_version = importlib.metadata.version("optimum-rbln")
     target_version = list(filter(lambda v: Version(my_version) > Version(v), RBLN_VERSION_COMPATS.keys()))[0]
     for compat in RBLN_VERSION_COMPATS[target_version]:
@@ -70,5 +69,3 @@ def check_version_compats() -> None:
             "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
             ImportWarning,
         )
-
-    warnings.resetwarnings()
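The practical difference: 0.1.8 enabled `ImportWarning` globally and then called `warnings.resetwarnings()`, which also wiped whatever filters the host application had installed; 0.1.9 scopes the filter to modules matching `optimum.rbln` and drops the reset. A minimal illustration; the application-side filter is hypothetical:

```python
import warnings

# A filter the host application might have installed for its own reasons.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# 0.1.9 behaviour: ImportWarnings are forced on only for warnings raised from
# modules whose name matches "optimum.rbln"; with no trailing resetwarnings(),
# the application's DeprecationWarning filter above stays in effect.
warnings.filterwarnings(action="always", category=ImportWarning, module="optimum.rbln")
```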
optimum/rbln/utils/runtime_utils.py

@@ -42,8 +42,9 @@ class RBLNPytorchRuntime:
         return self.forward(*args, **kwds)
 
     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+        # filtering uselss args or kwarg such as None.
         args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
-        kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor), kwargs.items()))
+        kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
         output = self.runtime(*args, **kwargs)
         return output
 
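`RBLNPytorchRuntime.forward` strips non-tensor arguments (typically `None` placeholders from HF-style signatures) before invoking the compiled runtime; the change additionally whitelists an `out` kwarg by name, so pre-allocated output buffers pass through even when they are not a plain tensor. An isolated sketch of the filtering, with illustrative values:

```python
import torch

args = (torch.ones(2), None)          # the None is dropped
kwargs = {
    "attention_mask": torch.ones(2),  # kept: a tensor
    "token_type_ids": None,           # dropped: not a tensor
    "out": [torch.empty(2)],          # kept by name as of 0.1.9
}

args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
kwargs = dict(filter(lambda kv: isinstance(kv[1], torch.Tensor) or kv[0] == "out", kwargs.items()))
print(len(args), sorted(kwargs))  # 1 ['attention_mask', 'out']
```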
{optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: optimum-rbln
-Version: 0.1.8
+Version: 0.1.9
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
 It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Keywords: transformers,diffusers,inference,rbln,atom,rebel
@@ -21,10 +21,12 @@ Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
 Requires-Python: <3.11,>=3.8
 Requires-Dist: torch<=2.2.1
+Requires-Dist: torchvision<=0.17.1
+Requires-Dist: torchaudio<=2.2.1
 Requires-Dist: optimum<=1.20.0
 Requires-Dist: accelerate>=0.28.0
-Requires-Dist: transformers<=4.40.2
-Requires-Dist: diffusers<=0.
+Requires-Dist: transformers<=4.40.2,>=4.38.0
+Requires-Dist: diffusers<=0.30.1
 Requires-Dist: einops>=0.8.0
 Requires-Dist: packaging>=24.1
 Requires-Dist: pytest>=8.1.1; extra == "tests"
@@ -99,6 +101,11 @@ To install optional dependencies from all groups, specify `-G:all` option.
 pdm install -G:all
 ```
 
+If you want to install optimum-rbln as [editable mode](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs) in existing venv,
+```bash
+(venv) pip install -e .
+```
+
 ## How to use it?
 
 ### Quick Start