optimum-rbln 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. optimum/rbln/__init__.py +3 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  4. optimum/rbln/diffusers/models/controlnet.py +4 -3
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
  7. optimum/rbln/modeling_alias.py +5 -1
  8. optimum/rbln/modeling_base.py +53 -19
  9. optimum/rbln/transformers/__init__.py +3 -1
  10. optimum/rbln/transformers/models/__init__.py +1 -0
  11. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +4 -3
  13. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +137 -22
  14. optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
  16. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
  17. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
  18. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
  19. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  20. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  21. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  22. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  23. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +8 -2
  24. optimum/rbln/transformers/utils/__init__.py +0 -0
  25. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  26. optimum/rbln/utils/import_utils.py +1 -4
  27. optimum/rbln/utils/runtime_utils.py +2 -1
  28. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +10 -3
  29. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +31 -26
  30. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  31. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
@@ -23,14 +23,19 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from typing import Any, Callable
26
+ from typing import TYPE_CHECKING, Any, Callable
27
27
 
28
- from transformers import GemmaForCausalLM, PreTrainedModel
28
+ from transformers import GemmaForCausalLM
29
29
 
30
30
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
31
31
  from .gemma_architecture import GemmaWrapper
32
32
 
33
33
 
34
+ if TYPE_CHECKING:
35
+ from transformers import PreTrainedModel
36
+
37
+ from ....modeling_config import RBLNConfig
38
+
34
39
  logger = logging.getLogger(__name__)
35
40
 
36
41
 
@@ -46,7 +51,8 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
46
51
  """
47
52
 
48
53
  @classmethod
49
- def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
54
+ def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
55
+ rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
50
56
  return GemmaWrapper(model, rbln_max_seq_len).eval()
51
57
 
52
58
  def __getattr__(self, __name: str) -> Any:
@@ -23,23 +23,18 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
26
+ from typing import TYPE_CHECKING, Any, Callable
27
27
 
28
- from transformers import GPT2LMHeadModel, PretrainedConfig, PreTrainedModel
28
+ from transformers import GPT2LMHeadModel
29
29
 
30
- from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
30
+ from ....modeling_config import RBLNConfig
31
31
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
32
32
  from .gpt2_architecture import GPT2LMHeadModelWrapper
33
33
 
34
34
 
35
35
  logger = logging.getLogger(__name__)
36
36
  if TYPE_CHECKING:
37
- from transformers import (
38
- AutoFeatureExtractor,
39
- AutoProcessor,
40
- AutoTokenizer,
41
- PretrainedConfig,
42
- )
37
+ from transformers import PreTrainedModel
43
38
 
44
39
 
45
40
  class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -57,7 +52,8 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
57
52
  """
58
53
 
59
54
  @classmethod
60
- def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
55
+ def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
56
+ rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
61
57
  return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
62
58
 
63
59
  def __getattr__(self, __name: str) -> Any:
@@ -74,82 +70,3 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
74
70
  if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
75
71
  return redirect(val)
76
72
  return val
77
-
78
- @classmethod
79
- def _get_rbln_config(
80
- cls,
81
- preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
82
- model_config: "PretrainedConfig",
83
- rbln_max_seq_len: Optional[int] = None,
84
- rbln_batch_size: Optional[int] = None,
85
- **kwargs,
86
- ) -> RBLNConfig:
87
- meta = {}
88
-
89
- prefill_chunk_size = 128
90
- if rbln_max_seq_len is None: # differenct from llama
91
- rbln_max_seq_len = getattr(model_config, "n_positions", None)
92
- rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
93
-
94
- meta["rbln_max_seq_len"] = rbln_max_seq_len
95
- meta["rbln_batch_size"] = rbln_batch_size
96
- meta["rbln_prefill_chunk_size"] = prefill_chunk_size
97
-
98
- def get_input_info(
99
- batch_size,
100
- query_length,
101
- ):
102
- head_dim = (
103
- model_config.head_dim
104
- if hasattr(model_config, "head_dim")
105
- else model_config.hidden_size // model_config.n_head
106
- )
107
- input_info = [
108
- ("input_ids", [batch_size, query_length], "int64"),
109
- ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
110
- (
111
- "cache_position",
112
- [batch_size, query_length],
113
- "int32",
114
- ),
115
- ("batch_position", [], "int16"),
116
- ]
117
-
118
- input_info.extend(
119
- [
120
- (
121
- f"past_key_values_{i}",
122
- [
123
- rbln_batch_size,
124
- model_config.n_head, # differenct from llama
125
- rbln_max_seq_len,
126
- head_dim,
127
- ],
128
- "float32",
129
- )
130
- for i in range(model_config.n_layer * 2) # differenct from llama
131
- ]
132
- )
133
-
134
- return input_info
135
-
136
- prefill_input_info = get_input_info(
137
- batch_size=1,
138
- query_length=prefill_chunk_size,
139
- )
140
- dec_input_info = get_input_info(
141
- batch_size=rbln_batch_size,
142
- query_length=1,
143
- )
144
-
145
- prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
146
- dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
147
-
148
- dec_rbln_runtime_config.batch_size = rbln_batch_size
149
-
150
- rbln_config = RBLNConfig.from_rbln_runtime_configs(
151
- [prefill_rbln_runtime_config, dec_rbln_runtime_config],
152
- _rbln_meta=meta,
153
- )
154
-
155
- return rbln_config
@@ -23,14 +23,19 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from typing import Any, Callable
26
+ from typing import TYPE_CHECKING, Any, Callable
27
27
 
28
- from transformers import LlamaForCausalLM, PreTrainedModel
28
+ from transformers import LlamaForCausalLM
29
29
 
30
30
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
31
31
  from .llama_architecture import LlamaWrapper
32
32
 
33
33
 
34
+ if TYPE_CHECKING:
35
+ from transformers import PreTrainedModel
36
+
37
+ from ....modeling_config import RBLNConfig
38
+
34
39
  logger = logging.getLogger(__name__)
35
40
 
36
41
 
@@ -46,7 +51,8 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
46
51
  """
47
52
 
48
53
  @classmethod
49
- def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
54
+ def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
55
+ rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
50
56
  return LlamaWrapper(model, rbln_max_seq_len).eval()
51
57
 
52
58
  def __getattr__(self, __name: str) -> Any:
@@ -23,11 +23,9 @@
23
23
 
24
24
  import inspect
25
25
  import logging
26
- from typing import TYPE_CHECKING, Any, Callable, Optional, Union
26
+ from typing import TYPE_CHECKING, Any, Callable
27
27
 
28
- from transformers import PretrainedConfig, PreTrainedModel
29
-
30
- from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
28
+ from ....modeling_config import RBLNConfig
31
29
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
32
30
  from .hf_hub_cached.modeling_midm import MidmLMHeadModel
33
31
  from .midm_architecture import (
@@ -38,10 +36,7 @@ from .midm_architecture import (
38
36
  logger = logging.getLogger(__name__)
39
37
  if TYPE_CHECKING:
40
38
  from transformers import (
41
- AutoFeatureExtractor,
42
- AutoProcessor,
43
- AutoTokenizer,
44
- PretrainedConfig,
39
+ PreTrainedModel,
45
40
  )
46
41
 
47
42
 
@@ -60,7 +55,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
60
55
  """
61
56
 
62
57
  @classmethod
63
- def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
58
+ def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
59
+ rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
64
60
  return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
65
61
 
66
62
  def __getattr__(self, __name: str) -> Any:
@@ -77,82 +73,3 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
77
73
  if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
78
74
  return redirect(val)
79
75
  return val
80
-
81
- @classmethod
82
- def _get_rbln_config(
83
- cls,
84
- preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
85
- model_config: "PretrainedConfig",
86
- rbln_max_seq_len: Optional[int] = None,
87
- rbln_batch_size: Optional[int] = None,
88
- **kwargs,
89
- ) -> RBLNConfig:
90
- meta = {}
91
-
92
- prefill_chunk_size = 128
93
- if rbln_max_seq_len is None:
94
- rbln_max_seq_len = getattr(model_config, "n_positions", None)
95
- rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
96
-
97
- meta["rbln_max_seq_len"] = rbln_max_seq_len
98
- meta["rbln_batch_size"] = rbln_batch_size
99
- meta["rbln_prefill_chunk_size"] = prefill_chunk_size
100
-
101
- def get_input_info(
102
- batch_size,
103
- query_length,
104
- ):
105
- head_dim = (
106
- model_config.head_dim
107
- if hasattr(model_config, "head_dim")
108
- else model_config.hidden_size // model_config.n_head
109
- )
110
- input_info = [
111
- ("input_ids", [batch_size, query_length], "int64"),
112
- ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
113
- (
114
- "cache_position",
115
- [batch_size, query_length],
116
- "int32",
117
- ),
118
- ("batch_position", [], "int16"),
119
- ]
120
-
121
- input_info.extend(
122
- [
123
- (
124
- f"past_key_values_{i}",
125
- [
126
- rbln_batch_size,
127
- model_config.n_head,
128
- rbln_max_seq_len,
129
- head_dim,
130
- ],
131
- "float32",
132
- )
133
- for i in range(model_config.n_layer * 2)
134
- ]
135
- )
136
-
137
- return input_info
138
-
139
- prefill_input_info = get_input_info(
140
- batch_size=1,
141
- query_length=prefill_chunk_size,
142
- )
143
- dec_input_info = get_input_info(
144
- batch_size=rbln_batch_size,
145
- query_length=1,
146
- )
147
-
148
- prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
149
- dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
150
-
151
- dec_rbln_runtime_config.batch_size = rbln_batch_size
152
-
153
- rbln_config = RBLNConfig.from_rbln_runtime_configs(
154
- [prefill_rbln_runtime_config, dec_rbln_runtime_config],
155
- _rbln_meta=meta,
156
- )
157
-
158
- return rbln_config
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .modeling_mistral import RBLNMistralForCausalLM
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+
25
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
26
+
27
+
28
class MistralForCausalLMWrapper(DecoderOnlyWrapper):
    """Wrapper used to compile Mistral causal-LM models for RBLN.

    It adds nothing to ``DecoderOnlyWrapper``: Mistral reuses the generic
    decoder-only wrapper implementation unchanged.
    """
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from typing import TYPE_CHECKING, Any, Callable
27
+
28
+ from transformers import MistralForCausalLM
29
+
30
+ from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
31
+ from .mistral_architecture import MistralForCausalLMWrapper
32
+
33
+
34
+ if TYPE_CHECKING:
35
+ from transformers import PreTrainedModel
36
+
37
+ from ....modeling_config import RBLNConfig
38
+
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The Mistral Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based MistralForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers MistralForCausalLM model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.
    """

    @classmethod
    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
        """Wrap ``model`` in ``MistralForCausalLMWrapper`` so it can be compiled for RBLN.

        Args:
            model: The Hugging Face Mistral causal-LM model to wrap.
            rbln_config: RBLN compilation config; ``rbln_config.meta["rbln_max_seq_len"]``
                is passed to the wrapper as the maximum sequence length.

        Returns:
            The ``MistralForCausalLMWrapper`` instance, switched to eval mode.

        Raises:
            KeyError: If ``rbln_config.meta`` has no ``"rbln_max_seq_len"`` entry.
        """
        # First argument of a classmethod is the class itself, hence ``cls``.
        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
        return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()

    def __getattr__(self, __name: str) -> Any:
        """Resolve attributes missing on this object from ``MistralForCausalLM``.

        The attribute is read from the ``MistralForCausalLM`` class. If it is callable
        and its signature has a ``self`` parameter, it is returned as a function that
        passes this RBLN model as ``self``; otherwise it is returned as-is.

        Raises:
            AttributeError: If ``MistralForCausalLM`` has no such attribute either.
        """

        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(MistralForCausalLM, __name)

        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)

        return val
@@ -70,7 +70,7 @@ class RBLNWav2Vec2ForCTC(RBLNModel):
70
70
  auto_model_class = AutoModelForMaskedLM
71
71
 
72
72
  @classmethod
73
- def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
73
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
74
74
  return _Wav2Vec2(model).eval()
75
75
 
76
76
  @classmethod
@@ -36,6 +36,7 @@ logger = logging.getLogger(__name__)
36
36
  if TYPE_CHECKING:
37
37
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
38
38
 
39
+
39
40
  class RBLNXLMRobertaModel(RBLNModel):
40
41
  auto_model_class = AutoModel # feature extraction
41
42
  original_model_class = XLMRobertaModel
@@ -81,7 +82,6 @@ class RBLNXLMRobertaModel(RBLNModel):
81
82
  rbln_model_input_names: Optional[List[str]] = None,
82
83
  rbln_batch_size: Optional[int] = None,
83
84
  ) -> RBLNConfig:
84
-
85
85
  max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
86
86
  model_config, "max_position_embeddings", None
87
87
  )
@@ -118,7 +118,13 @@ class RBLNXLMRobertaModel(RBLNModel):
118
118
 
119
119
  return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
120
120
 
121
- def forward(self, input_ids: "torch.Tensor", attention_mask: "torch.Tensor", token_type_ids: "torch.Tensor" = None, **kwargs):
121
+ def forward(
122
+ self,
123
+ input_ids: "torch.Tensor",
124
+ attention_mask: "torch.Tensor",
125
+ token_type_ids: "torch.Tensor" = None,
126
+ **kwargs,
127
+ ):
122
128
  if token_type_ids is None:
123
129
  token_type_ids = torch.zeros_like(input=input_ids, dtype=torch.int64)
124
130
  output = super().forward(input_ids, attention_mask, token_type_ids)
File without changes
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+
25
+ from typing import Any, List
26
+
27
+ import torch
28
+ from torch.nn import Linear, Parameter
29
+ from torch.nn import functional as F
30
+
31
+
32
# Module-name suffixes (last component of a dotted sub-module path) whose
# ``torch.nn.Linear`` layers ``replace_quantized_linear_layers`` converts.
QUANTIZED_WEIGHTS = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
41
+
42
+
43
def replace_quantized_linear_layers(
    module: torch.nn.Module,
) -> None:
    """Swap each targeted linear layer in ``module`` for a quantized linear layer.

    A sub-module is targeted when it is a ``torch.nn.Linear`` and the last
    component of its dotted name appears in ``QUANTIZED_WEIGHTS``. Each target is
    converted with ``get_qlinear`` and re-attached to its parent module. The
    names of all converted layers are printed at the end.

    Args:
        module (torch.nn.Module): Root module to search, for example an
            instance of a model like LlamaForCausalLM().
    """
    replaced: List[str] = []

    for qualified_name, submodule in module.named_modules():
        path = qualified_name.split(".")
        if not isinstance(submodule, torch.nn.Linear) or path[-1] not in QUANTIZED_WEIGHTS:
            continue
        owner = access_attribute(module, path[:-1])
        setattr(owner, path[-1], get_qlinear(submodule))
        replaced.append(qualified_name)

    names_repr = ", ".join(replaced)
    print(f"Replace the following linear layers as qlinear layer:\n {{{names_repr}}}")
64
+
65
+
66
def access_attribute(obj: Any, tokens: List[str]) -> Any:
    """Follow a chain of attribute names starting from ``obj``.

    Args:
        obj: The object to start from.
        tokens (List[str]): Attribute names applied left to right; an empty
            list returns ``obj`` itself.

    Returns:
        Any: The value reached after the last attribute name.

    Raises:
        AttributeError: If any attribute along the chain does not exist.
    """
    current = obj
    for attr_name in tokens:
        current = getattr(current, attr_name)
    return current
83
+
84
+
85
def get_qlinear(layer: Linear):
    """Convert ``layer`` in place into a weight-only quantized linear layer.

    The existing weight parameter is removed and re-registered as
    ``layer.weight.to(torch.int8)``. A float32 ``scales`` parameter of length
    ``layer.out_features`` is added, initialised to ones. Both new parameters have
    ``requires_grad=False``. ``layer.forward`` is then replaced by the quantized
    forward below.

    Args:
        layer (Linear): The linear layer to convert; it is mutated.

    Returns:
        Linear: The same ``layer`` object.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Perform weight-only quantized linear layer.

        Forward workflow:
            - cast weight to high precision
            - multiply scale factor to weight
            - call torch.nn.functional linear
        Note:
            - Please do not modify the following workflow.
            - If the workflow must be changed, please contact Rebellions.
        """
        if inputs.dtype != self.scales.dtype:
            raise TypeError(f"Expected tensor of dtype {self.scales.dtype} but got {inputs.dtype}")
        w_fp = self.weight.type(inputs.dtype)
        w_fp *= self.scales.view(-1, 1)
        return F.linear(inputs, w_fp, self.bias)

    int8_weight = layer.weight.to(torch.int8)
    # Unregister the original parameter before registering the int8 replacement.
    layer.weight = None
    del layer.weight
    layer.weight = Parameter(int8_weight, requires_grad=False)
    layer.scales = Parameter(torch.ones(layer.out_features, dtype=torch.float32), requires_grad=False)

    def bound_forward(*args, **kwargs):
        # Instance attribute shadows Linear.forward, so module(x) routes here.
        return forward(layer, *args, **kwargs)

    layer.forward = bound_forward
    return layer
@@ -53,8 +53,7 @@ def is_rbln_available() -> bool:
53
53
 
54
54
 
55
55
  def check_version_compats() -> None:
56
- warnings.filterwarnings(action="always", category=ImportWarning)
57
-
56
+ warnings.filterwarnings(action="always", category=ImportWarning, module="optimum.rbln")
58
57
  my_version = importlib.metadata.version("optimum-rbln")
59
58
  target_version = list(filter(lambda v: Version(my_version) > Version(v), RBLN_VERSION_COMPATS.keys()))[0]
60
59
  for compat in RBLN_VERSION_COMPATS[target_version]:
@@ -70,5 +69,3 @@ def check_version_compats() -> None:
70
69
  "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
71
70
  ImportWarning,
72
71
  )
73
-
74
- warnings.resetwarnings()
@@ -42,8 +42,9 @@ class RBLNPytorchRuntime:
42
42
  return self.forward(*args, **kwds)
43
43
 
44
44
  def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
45
+ # filtering uselss args or kwarg such as None.
45
46
  args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
46
- kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor), kwargs.items()))
47
+ kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
47
48
  output = self.runtime(*args, **kwargs)
48
49
  return output
49
50
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: optimum-rbln
3
- Version: 0.1.8
3
+ Version: 0.1.9
4
4
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators.
5
5
  It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
6
6
  Keywords: transformers,diffusers,inference,rbln,atom,rebel
@@ -21,10 +21,12 @@ Project-URL: Homepage, https://rebellions.ai
21
21
  Project-URL: Documentation, https://docs.rbln.ai
22
22
  Requires-Python: <3.11,>=3.8
23
23
  Requires-Dist: torch<=2.2.1
24
+ Requires-Dist: torchvision<=0.17.1
25
+ Requires-Dist: torchaudio<=2.2.1
24
26
  Requires-Dist: optimum<=1.20.0
25
27
  Requires-Dist: accelerate>=0.28.0
26
- Requires-Dist: transformers<=4.40.2
27
- Requires-Dist: diffusers<=0.29.2
28
+ Requires-Dist: transformers<=4.40.2,>=4.38.0
29
+ Requires-Dist: diffusers<=0.30.1
28
30
  Requires-Dist: einops>=0.8.0
29
31
  Requires-Dist: packaging>=24.1
30
32
  Requires-Dist: pytest>=8.1.1; extra == "tests"
@@ -99,6 +101,11 @@ To install optional dependencies from all groups, specify `-G:all` option.
99
101
  pdm install -G:all
100
102
  ```
101
103
 
104
+ To install optimum-rbln in [editable mode](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs) inside an existing venv, run:
105
+ ```bash
106
+ (venv) pip install -e .
107
+ ```
108
+
102
109
  ## How to use it?
103
110
 
104
111
  ### Quick Start