optimum-rbln 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. optimum/rbln/__init__.py +40 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +39 -32
  4. optimum/rbln/diffusers/models/controlnet.py +60 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +43 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  15. optimum/rbln/modeling_alias.py +8 -4
  16. optimum/rbln/modeling_base.py +512 -238
  17. optimum/rbln/modeling_config.py +152 -77
  18. optimum/rbln/modeling_seq2seq.py +166 -77
  19. optimum/rbln/transformers/__init__.py +37 -1
  20. optimum/rbln/transformers/models/__init__.py +21 -1
  21. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  22. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  23. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  24. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  25. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  26. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  27. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +128 -26
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +32 -7
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +406 -104
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  34. optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
  35. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
  36. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  37. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
  38. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
  39. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  40. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  41. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  42. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
  43. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  44. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  45. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  46. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  49. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  50. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +18 -12
  51. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  53. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  54. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +25 -16
  55. optimum/rbln/transformers/utils/__init__.py +0 -0
  56. optimum/rbln/transformers/utils/rbln_quantization.py +97 -0
  57. optimum/rbln/utils/import_utils.py +37 -5
  58. optimum/rbln/utils/logging.py +82 -0
  59. optimum/rbln/utils/runtime_utils.py +35 -1
  60. optimum/rbln/utils/timer_utils.py +19 -0
  61. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +15 -7
  62. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  63. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  64. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  65. optimum_rbln-0.1.8.dist-info/RECORD +0 -73
  66. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,106 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
+
28
+ from transformers import AutoModel, BartConfig, BartModel, PretrainedConfig
29
+
30
+ from ....modeling_base import RBLNModel
31
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
32
+
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
38
+
39
+
40
class RBLNBartModel(RBLNModel):
    """RBLN model class wrapping ``BartModel`` (feature extraction).

    The only behaviour defined here is how the compile-time configuration
    (input names, static input shapes) is derived for a Bart encoder-decoder
    backbone.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = BartModel
    original_config_class = BartConfig

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the ``RBLNConfig`` describing the static inputs to compile for.

        Args:
            preprocessors: Iterable of preprocessor objects (tokenizers etc.)
                that is searched for ``model_max_length`` and
                ``model_input_names``. ``None`` is treated as "no preprocessors".
            model_config: Model configuration; only ``max_position_embeddings``
                is read from it (via ``getattr``, so it may be absent).
            rbln_kwargs: User options. The keys read here are ``max_seq_len``,
                ``batch_size`` and ``model_input_names``. The mapping is also
                forwarded unchanged to ``RBLNConfig``.

        Returns:
            An ``RBLNConfig`` with a single compile config whose inputs are
            ``int64`` tensors of shape ``[batch_size, max_seq_len]``, one per
            model input name, and ``max_seq_len`` recorded in ``model_cfg``.

        Raises:
            ValueError: If the sequence length cannot be determined, exceeds
                ``max_position_embeddings``, or the input names cannot be
                determined.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # between every invocation of this classmethod.
        if rbln_kwargs is None:
            rbln_kwargs = {}
        # The parameter is declared Optional; iterate an empty tuple for None
        # instead of crashing with "NoneType is not iterable".
        candidates = () if preprocessors is None else preprocessors

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)

        # Resolution order: explicit option -> model config -> first
        # preprocessor that exposes `model_max_length`.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in candidates:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        if rbln_model_input_names is None:
            for tokenizer in candidates:
                if hasattr(tokenizer, "model_input_names"):
                    # BartModel's forward() does not take token_type_ids as input.
                    # (Added because some of the tokenizers includes 'token_type_ids')
                    # Build a new list rather than calling `.remove()` on the
                    # tokenizer's own attribute, which would silently change
                    # what the tokenizer reports to every later caller.
                    rbln_model_input_names = [
                        name for name in tokenizer.model_input_names if name != "token_type_ids"
                    ]
                    break

        if rbln_model_input_names is None:
            if hasattr(cls, "rbln_model_input_names"):
                rbln_model_input_names = cls.rbln_model_input_names
            else:
                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
                raise ValueError(
                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
                )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .modeling_bert import RBLNBertModel
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
+
28
+ from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig
29
+
30
+ from ....modeling_base import RBLNModel
31
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
32
+
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
38
+
39
+
40
class RBLNBertModel(RBLNModel):
    """RBLN model class wrapping ``BertModel`` (feature extraction).

    The only behaviour defined here is how the compile-time configuration
    (input names, static input shapes) is derived.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = BertModel
    original_config_class = BertConfig

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the ``RBLNConfig`` describing the static inputs to compile for.

        Args:
            preprocessors: Iterable of preprocessor objects (tokenizers etc.)
                that is searched for ``model_max_length`` and
                ``model_input_names``. ``None`` is treated as "no preprocessors".
            model_config: Model configuration; only ``max_position_embeddings``
                is read from it (via ``getattr``, so it may be absent).
            rbln_kwargs: User options. The keys read here are ``max_seq_len``,
                ``batch_size`` and ``model_input_names``. The mapping is also
                forwarded unchanged to ``RBLNConfig``.

        Returns:
            An ``RBLNConfig`` with a single compile config whose inputs are
            ``int64`` tensors of shape ``[batch_size, max_seq_len]``, one per
            model input name, and ``max_seq_len`` recorded in ``model_cfg``.

        Raises:
            ValueError: If the sequence length cannot be determined, exceeds
                ``max_position_embeddings``, or the input names cannot be
                determined.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # between every invocation of this classmethod.
        if rbln_kwargs is None:
            rbln_kwargs = {}
        # The parameter is declared Optional; iterate an empty tuple for None
        # instead of crashing with "NoneType is not iterable".
        candidates = () if preprocessors is None else preprocessors

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)

        # Resolution order: explicit option -> model config -> first
        # preprocessor that exposes `model_max_length`.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in candidates:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        # Input names: explicit option -> first preprocessor exposing
        # `model_input_names` -> class-level `rbln_model_input_names`.
        if rbln_model_input_names is None:
            for tokenizer in candidates:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_model_input_names = tokenizer.model_input_names
                    break

        if rbln_model_input_names is None:
            if hasattr(cls, "rbln_model_input_names"):
                rbln_model_input_names = cls.rbln_model_input_names
            else:
                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
                raise ValueError(
                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
                )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
@@ -21,4 +21,4 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
24
+ from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
@@ -22,14 +22,23 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import logging
25
- from typing import TYPE_CHECKING, Optional, Union
25
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
26
26
 
27
27
  import torch
28
- from transformers import AutoConfig, AutoModel, CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection
28
+ from transformers import (
29
+ AutoConfig,
30
+ AutoModel,
31
+ CLIPTextConfig,
32
+ CLIPTextModel,
33
+ CLIPTextModelWithProjection,
34
+ CLIPVisionConfig,
35
+ CLIPVisionModel,
36
+ )
37
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
29
38
  from transformers.models.clip.modeling_clip import CLIPTextModelOutput
30
39
 
31
40
  from ....modeling_base import RBLNModel
32
- from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
41
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
33
42
 
34
43
 
35
44
  logger = logging.getLogger(__name__)
@@ -41,12 +50,10 @@ if TYPE_CHECKING:
41
50
  class _TextEncoder(torch.nn.Module):
42
51
  def __init__(self, enc: "CLIPTextModel"):
43
52
  super().__init__()
44
- enc.config.return_dict = False
45
- enc.config.output_hidden_states = True
46
53
  self.enc = enc
47
54
 
48
55
  def forward(self, inp):
49
- enc_out = self.enc(inp)
56
+ enc_out = self.enc(inp, output_hidden_states=True, return_dict=False)
50
57
  return enc_out
51
58
 
52
59
 
@@ -55,9 +62,6 @@ class RBLNCLIPTextModel(RBLNModel):
55
62
  original_model_class = CLIPTextModel
56
63
  original_config_class = CLIPTextConfig
57
64
 
58
- def __post_init__(self, **kwargs):
59
- self.dtype = torch.float32
60
-
61
65
  @classmethod
62
66
  def from_pretrained(cls, *args, **kwargs):
63
67
  configtmp = AutoConfig.from_pretrained
@@ -70,7 +74,7 @@ class RBLNCLIPTextModel(RBLNModel):
70
74
  return rt
71
75
 
72
76
  @classmethod
73
- def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
77
+ def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
74
78
  return _TextEncoder(model).eval()
75
79
 
76
80
  @classmethod
@@ -78,28 +82,32 @@ class RBLNCLIPTextModel(RBLNModel):
78
82
  cls,
79
83
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
80
84
  model_config: "CLIPTextConfig",
85
+ rbln_kwargs: Dict[str, Any] = {},
81
86
  rbln_batch_size: Optional[int] = None,
82
- rbln_img_width: Optional[int] = None,
83
- rbln_img_height: Optional[int] = None,
84
87
  ) -> RBLNConfig:
85
- model_config.return_dict = False
88
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
86
89
  if rbln_batch_size is None:
87
90
  rbln_batch_size = 1
88
91
 
89
- rbln_runtime_config = RBLNRuntimeConfig(
90
- input_info=[
91
- (
92
- "input_ids",
93
- [
94
- rbln_batch_size,
95
- model_config.max_position_embeddings,
96
- ],
97
- "int64",
98
- ),
99
- ],
100
- )
92
+ model_config.return_dict = False
101
93
 
102
- rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
94
+ input_info = [
95
+ (
96
+ "input_ids",
97
+ [
98
+ rbln_batch_size,
99
+ model_config.max_position_embeddings,
100
+ ],
101
+ "int64",
102
+ ),
103
+ ]
104
+
105
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
106
+ rbln_config = RBLNConfig(
107
+ rbln_cls=cls.__name__,
108
+ compile_cfgs=[rbln_compile_config],
109
+ rbln_kwargs=rbln_kwargs,
110
+ )
103
111
  return rbln_config
104
112
 
105
113
  def forward(self, input_ids: "torch.Tensor", **kwargs):
@@ -113,3 +121,97 @@ class RBLNCLIPTextModel(RBLNModel):
113
121
 
114
122
  class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
115
123
  original_model_class = CLIPTextModelWithProjection
124
+
125
+
126
class _VisionEncoder(torch.nn.Module):
    """Thin module around a CLIP vision encoder with fixed call options.

    The wrapped encoder is always invoked with ``output_hidden_states=True``
    and ``return_dict=False``; those flags are passed per call instead of
    being written into the encoder's config.
    """

    def __init__(self, enc: CLIPVisionModel):
        super().__init__()
        self.enc = enc

    def forward(self, inp):
        """Run the wrapped encoder on ``inp`` and return its raw output."""
        return self.enc(inp, output_hidden_states=True, return_dict=False)
134
+
135
+
136
class RBLNCLIPVisionModel(RBLNModel):
    """RBLN model class wrapping ``CLIPVisionModel``."""

    original_model_class = CLIPVisionModel
    original_config_class = CLIPVisionConfig

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the model while ``AutoConfig``/``AutoModel`` point at the CLIP vision classes.

        ``AutoConfig.from_pretrained`` and ``AutoModel.from_pretrained`` are
        temporarily replaced by the loaders of ``original_config_class`` and
        ``original_model_class`` for the duration of the parent call.
        NOTE(review): presumably the base class loads through the Auto
        classes, which would not pick the vision sub-model by themselves -
        confirm against ``RBLNModel.from_pretrained``.
        """
        configtmp = AutoConfig.from_pretrained
        modeltmp = AutoModel.from_pretrained
        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
        try:
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Always undo the patch: without this, a failed load would leave
            # the process-wide Auto classes redirected to the CLIP loaders.
            AutoConfig.from_pretrained = configtmp
            AutoModel.from_pretrained = modeltmp

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
        """Wrap ``model`` in ``_VisionEncoder`` and switch it to eval mode.

        ``rbln_config`` is accepted for signature compatibility and not used.
        """
        return _VisionEncoder(model).eval()

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "CLIPVisionConfig",
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the ``RBLNConfig`` for a fixed-size ``pixel_values`` input.

        Args:
            preprocessors: Not used by this implementation.
            model_config: Model configuration; ``image_size`` is read from it
                when the user did not supply one.
            rbln_kwargs: User options. The keys read here are ``batch_size``
                (defaults to 1) and ``image_size`` (an int for a square image,
                or an indexable pair). The mapping is also forwarded unchanged
                to ``RBLNConfig``.

        Returns:
            An ``RBLNConfig`` with one compile config for a ``float32``
            ``pixel_values`` input of shape
            ``[batch_size, 3, image_size[0], image_size[1]]``; ``batch_size``
            and ``image_size`` are also stored in ``model_cfg``.

        Raises:
            ValueError: If no image size is available from either source.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # between every invocation of this classmethod.
        if rbln_kwargs is None:
            rbln_kwargs = {}

        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        if rbln_batch_size is None:
            # Covers both a missing key and an explicit `batch_size=None`,
            # which would otherwise end up as `None` inside the input shape.
            rbln_batch_size = 1

        rbln_image_size = rbln_kwargs.get("image_size", None)
        if rbln_image_size is None:
            rbln_image_size = getattr(model_config, "image_size", None)

        if rbln_image_size is None:
            raise ValueError("`rbln_image_size` should be specified!")

        if isinstance(rbln_image_size, int):
            # A single int describes a square image.
            rbln_image_size = (rbln_image_size, rbln_image_size)

        rbln_compile_config = RBLNCompileConfig(
            input_info=[
                (
                    "pixel_values",
                    [
                        rbln_batch_size,
                        3,
                        rbln_image_size[0],
                        rbln_image_size[1],
                    ],
                    "float32",
                )
            ]
        )

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update(
            {
                "batch_size": rbln_batch_size,
                "image_size": rbln_image_size,
            }
        )

        return rbln_config

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Run the compiled model on ``pixel_values``.

        Extra keyword arguments are not forwarded; a warning is logged when
        any of them carries a value other than ``None``/``False``.

        Returns:
            ``BaseModelOutputWithPooling`` built from the parent ``forward``
            output: element 0 as ``last_hidden_state``, element 1 as
            ``pooler_output`` and the remaining elements as ``hidden_states``.
        """
        # Identity checks instead of `any(kwargs.values())`: truth-testing a
        # multi-element tensor raises, so a tensor-valued kwarg used to crash
        # here instead of producing the warning.
        if any(value is not None and value is not False for value in kwargs.values()):
            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")

        output = super().forward(pixel_values)
        return BaseModelOutputWithPooling(
            last_hidden_state=output[0],
            pooler_output=output[1],
            hidden_states=output[2:],
        )
@@ -49,18 +49,19 @@ class DecoderOnlyWrapper(torch.nn.Module):
49
49
  self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
50
50
  )
51
51
  self.max_seq_len = max_seq_len
52
+ self.rope_scaling = getattr(self.config, "rope_scaling", None)
52
53
  self.rotary_emb = self._init_rope()
53
54
 
54
55
  def _init_rope(self):
55
- if self.config.rope_scaling is None:
56
+ if self.rope_scaling is None:
56
57
  rotary_emb = RotaryEmbedding(
57
58
  self.head_dim,
58
59
  max_position_embeddings=self.max_position_embeddings,
59
60
  base=self.config.rope_theta,
60
61
  )
61
62
  else:
62
- scaling_type = self.config.rope_scaling["type"]
63
- scaling_factor = self.config.rope_scaling["factor"]
63
+ scaling_type = self.rope_scaling["type"]
64
+ scaling_factor = self.rope_scaling["factor"]
64
65
  if scaling_type == "linear":
65
66
  rotary_emb = LinearScalingRotaryEmbedding(
66
67
  self.head_dim,
@@ -92,17 +93,29 @@ class DecoderOnlyWrapper(torch.nn.Module):
92
93
 
93
94
  def forward(
94
95
  self,
95
- input_ids,
96
+ input_ids_or_inputs_embeds,
96
97
  attention_mask,
97
98
  cache_position,
98
99
  batch_position,
100
+ query_idx,
99
101
  *past_key_values,
100
102
  ):
101
- if input_ids.shape[1] == 1:
103
+ if input_ids_or_inputs_embeds.shape[1] == 1:
102
104
  rbln_batch_position = None
103
105
  else:
104
106
  rbln_batch_position = batch_position
105
107
 
108
+ if input_ids_or_inputs_embeds.ndim == 2:
109
+ # input_ids
110
+ input_ids = input_ids_or_inputs_embeds
111
+ inputs_embeds = None
112
+ elif input_ids_or_inputs_embeds.ndim == 3:
113
+ # inputs_embeds
114
+ input_ids = None
115
+ inputs_embeds = input_ids_or_inputs_embeds
116
+ else:
117
+ raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
118
+
106
119
  # Formatting list of past_kv to DynamicCache class.
107
120
  past_key_values = RebelDynamicCache.from_input_format(
108
121
  cache_position,
@@ -114,6 +127,7 @@ class DecoderOnlyWrapper(torch.nn.Module):
114
127
  outputs = forward_dict["wrapper"](
115
128
  self.model,
116
129
  input_ids=input_ids,
130
+ inputs_embeds=inputs_embeds,
117
131
  attention_mask=attention_mask,
118
132
  position_ids=cache_position,
119
133
  past_key_values=past_key_values,
@@ -123,11 +137,14 @@ class DecoderOnlyWrapper(torch.nn.Module):
123
137
  )
124
138
 
125
139
  hidden_states = outputs[0]
140
+ if batch_position >= 0:
141
+ hidden_states = hidden_states[:, query_idx].unsqueeze(1)
142
+
126
143
  logits = self.lm_head(hidden_states)
127
144
 
128
145
  output = (logits,) + outputs[1:]
129
146
 
130
- return output, batch_position
147
+ return output, batch_position + query_idx
131
148
 
132
149
 
133
150
  class DecoderOnlyAttention:
@@ -322,8 +339,16 @@ class DecoderOnlyModel:
322
339
  forward_dict: Optional[Dict[str, classmethod]] = None,
323
340
  rotary_pos_emb=None,
324
341
  ) -> BaseModelOutputWithPast:
342
+ # retrieve input_ids and inputs_embeds
343
+ if (input_ids is None) ^ (inputs_embeds is not None):
344
+ raise ValueError(
345
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
346
+ )
347
+
325
348
  # embed positions
326
- inputs_embeds = self.embed_tokens(input_ids)
349
+ if inputs_embeds is None:
350
+ inputs_embeds = self.embed_tokens(input_ids)
351
+
327
352
  hidden_states = inputs_embeds
328
353
  attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
329
354