optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,106 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
+
28
+ from transformers import AutoModel, BartConfig, BartModel, PretrainedConfig
29
+
30
+ from ....modeling_base import RBLNModel
31
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
32
+
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
38
+
39
+
40
class RBLNBartModel(RBLNModel):
    """RBLN-compiled ``BartModel`` used for feature extraction.

    The only specialization is :meth:`_get_rbln_config`, which fixes a static
    ``[batch_size, max_seq_len]`` int64 shape for every model input before compilation.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = BartModel
    original_config_class = BartConfig

    @classmethod
    def _model_input_names_from_preprocessors(cls, preprocessors) -> Optional[list]:
        """Derive input names from the first preprocessor exposing ``model_input_names``.

        Only names accepted by ``original_model_class.forward`` are kept, ordered as in
        its signature, so the compiled inputs line up with ``forward()``'s arguments.
        Returns ``None`` when no preprocessor exposes ``model_input_names`` or when none
        of its names are accepted.
        """
        for preprocessor in preprocessors:
            if hasattr(preprocessor, "model_input_names"):
                # Build a NEW list rather than editing ``model_input_names`` in place:
                # tokenizers commonly share that list at class level, so an in-place
                # ``remove`` would silently change every other tokenizer in the process.
                provided = set(preprocessor.model_input_names or ())
                forward_params = inspect.signature(cls.original_model_class.forward).parameters
                # BartModel's forward() does not take token_type_ids as input
                # (some tokenizers include 'token_type_ids'); keeping only names present
                # in the signature drops it along with any other unsupported name.
                names = [name for name in forward_params if name in provided]
                return names or None
        return None

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the RBLN compile configuration for ``BartModel``.

        Args:
            preprocessors: Iterable of preprocessors (typically tokenizers) loaded with the
                model; used as a fallback source of ``model_max_length`` and
                ``model_input_names``. ``None`` is treated as an empty collection.
            model_config: HF config; ``max_position_embeddings`` is read from it if present.
            rbln_kwargs: User options. Keys read here: ``max_seq_len``, ``batch_size``,
                ``model_input_names``. Also forwarded to ``RBLNConfig``.

        Returns:
            An ``RBLNConfig`` with a single compile config whose inputs are all
            ``(name, [batch_size, max_seq_len], "int64")``; ``model_cfg["max_seq_len"]`` is set.

        Raises:
            ValueError: if no sequence length can be determined, if it exceeds
                ``max_position_embeddings``, or if no input names can be determined.
        """
        if rbln_kwargs is None:
            # Fresh dict per call; a ``{}`` default would be one dict shared by all calls
            # and stored inside every returned RBLNConfig.
            rbln_kwargs = {}
        if preprocessors is None:
            preprocessors = []

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)

        # Sequence length priority: explicit kwarg > model config > tokenizer limit.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        # Input-name priority: explicit kwarg > tokenizer > class attribute.
        if rbln_model_input_names is None:
            rbln_model_input_names = cls._model_input_names_from_preprocessors(preprocessors)
        if rbln_model_input_names is None:
            if hasattr(cls, "rbln_model_input_names"):
                rbln_model_input_names = cls.rbln_model_input_names
            else:
                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
                raise ValueError(
                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
                )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .modeling_bert import RBLNBertModel
@@ -0,0 +1,102 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import inspect
25
+ import logging
26
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
+
28
+ from transformers import AutoModel, BertConfig, BertModel, PretrainedConfig
29
+
30
+ from ....modeling_base import RBLNModel
31
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
32
+
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ if TYPE_CHECKING:
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
38
+
39
+
40
class RBLNBertModel(RBLNModel):
    """RBLN-compiled ``BertModel`` used for feature extraction.

    The only specialization is :meth:`_get_rbln_config`, which fixes a static
    ``[batch_size, max_seq_len]`` int64 shape for every model input before compilation.
    """

    auto_model_class = AutoModel  # feature extraction
    original_model_class = BertModel
    original_config_class = BertConfig

    @classmethod
    def _model_input_names_from_preprocessors(cls, preprocessors) -> Optional[list]:
        """Derive input names from the first preprocessor exposing ``model_input_names``.

        Only names accepted by ``original_model_class.forward`` are kept, ordered as in
        its signature. This matters for BERT because the common tokenizer order
        ``input_ids, token_type_ids, attention_mask`` differs from ``BertModel.forward``'s
        ``input_ids, attention_mask, token_type_ids``. Returns ``None`` when no
        preprocessor exposes ``model_input_names`` or none of its names are accepted.
        """
        for preprocessor in preprocessors:
            if hasattr(preprocessor, "model_input_names"):
                # Read-only use: a new list is built, the tokenizer's own list is untouched.
                provided = set(preprocessor.model_input_names or ())
                forward_params = inspect.signature(cls.original_model_class.forward).parameters
                names = [name for name in forward_params if name in provided]
                return names or None
        return None

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the RBLN compile configuration for ``BertModel``.

        Args:
            preprocessors: Iterable of preprocessors (typically tokenizers) loaded with the
                model; used as a fallback source of ``model_max_length`` and
                ``model_input_names``. ``None`` is treated as an empty collection.
            model_config: HF config; ``max_position_embeddings`` is read from it if present.
            rbln_kwargs: User options. Keys read here: ``max_seq_len``, ``batch_size``,
                ``model_input_names``. Also forwarded to ``RBLNConfig``.

        Returns:
            An ``RBLNConfig`` with a single compile config whose inputs are all
            ``(name, [batch_size, max_seq_len], "int64")``; ``model_cfg["max_seq_len"]`` is set.

        Raises:
            ValueError: if no sequence length can be determined, if it exceeds
                ``max_position_embeddings``, or if no input names can be determined.
        """
        if rbln_kwargs is None:
            # Fresh dict per call; a ``{}`` default would be one dict shared by all calls
            # and stored inside every returned RBLNConfig.
            rbln_kwargs = {}
        if preprocessors is None:
            preprocessors = []

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)

        # Sequence length priority: explicit kwarg > model config > tokenizer limit.
        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
        if rbln_max_seq_len is None:
            raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        # Input-name priority: explicit kwarg > tokenizer > class attribute.
        if rbln_model_input_names is None:
            rbln_model_input_names = cls._model_input_names_from_preprocessors(preprocessors)
        if rbln_model_input_names is None:
            if hasattr(cls, "rbln_model_input_names"):
                rbln_model_input_names = cls.rbln_model_input_names
            else:
                input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
                raise ValueError(
                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
                )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
@@ -21,4 +21,4 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
24
+ from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
@@ -22,14 +22,23 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import logging
25
- from typing import TYPE_CHECKING, Optional, Union
25
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
26
26
 
27
27
  import torch
28
- from transformers import AutoConfig, AutoModel, CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection
28
+ from transformers import (
29
+ AutoConfig,
30
+ AutoModel,
31
+ CLIPTextConfig,
32
+ CLIPTextModel,
33
+ CLIPTextModelWithProjection,
34
+ CLIPVisionConfig,
35
+ CLIPVisionModel,
36
+ )
37
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
29
38
  from transformers.models.clip.modeling_clip import CLIPTextModelOutput
30
39
 
31
40
  from ....modeling_base import RBLNModel
32
- from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
41
+ from ....modeling_config import RBLNCompileConfig, RBLNConfig
33
42
 
34
43
 
35
44
  logger = logging.getLogger(__name__)
@@ -41,12 +50,10 @@ if TYPE_CHECKING:
41
50
class _TextEncoder(torch.nn.Module):
    """Wrap a CLIP text encoder so every call uses the same output options.

    The wrapped encoder is always called with ``output_hidden_states=True`` and
    ``return_dict=False`` as call-time keyword arguments, so the model's own config
    is left untouched.
    """

    def __init__(self, enc: "CLIPTextModel"):
        super().__init__()
        self.enc = enc

    def forward(self, inp):
        return self.enc(inp, return_dict=False, output_hidden_states=True)
51
58
 
52
59
 
@@ -55,9 +62,6 @@ class RBLNCLIPTextModel(RBLNModel):
55
62
  original_model_class = CLIPTextModel
56
63
  original_config_class = CLIPTextConfig
57
64
 
58
- def __post_init__(self, **kwargs):
59
- self.dtype = torch.float32
60
-
61
65
  @classmethod
62
66
  def from_pretrained(cls, *args, **kwargs):
63
67
  configtmp = AutoConfig.from_pretrained
@@ -78,28 +82,32 @@ class RBLNCLIPTextModel(RBLNModel):
78
82
  cls,
79
83
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
80
84
  model_config: "CLIPTextConfig",
85
+ rbln_kwargs: Dict[str, Any] = {},
81
86
  rbln_batch_size: Optional[int] = None,
82
- rbln_img_width: Optional[int] = None,
83
- rbln_img_height: Optional[int] = None,
84
87
  ) -> RBLNConfig:
85
- model_config.return_dict = False
88
+ rbln_batch_size = rbln_kwargs.get("batch_size", None)
86
89
  if rbln_batch_size is None:
87
90
  rbln_batch_size = 1
88
91
 
89
- rbln_runtime_config = RBLNRuntimeConfig(
90
- input_info=[
91
- (
92
- "input_ids",
93
- [
94
- rbln_batch_size,
95
- model_config.max_position_embeddings,
96
- ],
97
- "int64",
98
- ),
99
- ],
100
- )
92
+ model_config.return_dict = False
101
93
 
102
- rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
94
+ input_info = [
95
+ (
96
+ "input_ids",
97
+ [
98
+ rbln_batch_size,
99
+ model_config.max_position_embeddings,
100
+ ],
101
+ "int64",
102
+ ),
103
+ ]
104
+
105
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
106
+ rbln_config = RBLNConfig(
107
+ rbln_cls=cls.__name__,
108
+ compile_cfgs=[rbln_compile_config],
109
+ rbln_kwargs=rbln_kwargs,
110
+ )
103
111
  return rbln_config
104
112
 
105
113
  def forward(self, input_ids: "torch.Tensor", **kwargs):
@@ -113,3 +121,97 @@ class RBLNCLIPTextModel(RBLNModel):
113
121
 
114
122
class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
    """``RBLNCLIPTextModel`` variant whose wrapped HF class is ``CLIPTextModelWithProjection``.

    All behavior is inherited; only ``original_model_class`` differs.
    """

    original_model_class = CLIPTextModelWithProjection
124
+
125
+
126
class _VisionEncoder(torch.nn.Module):
    """Wrap a CLIP vision encoder so every call uses the same output options.

    The wrapped encoder is always called with ``output_hidden_states=True`` and
    ``return_dict=False``, and the encoder's result is returned unchanged.
    """

    def __init__(self, enc: CLIPVisionModel):
        super().__init__()
        self.enc = enc

    def forward(self, inp):
        return self.enc(inp, return_dict=False, output_hidden_states=True)
134
+
135
+
136
class RBLNCLIPVisionModel(RBLNModel):
    """RBLN-compiled ``CLIPVisionModel``.

    Compiled for one static ``pixel_values`` input of shape
    ``[batch_size, 3, height, width]`` (float32). ``forward`` maps the compiled
    model's outputs onto a ``BaseModelOutputWithPooling``.
    """

    original_model_class = CLIPVisionModel
    original_config_class = CLIPVisionConfig

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Delegate to the base loader with ``AutoConfig``/``AutoModel`` redirected to CLIP vision.

        ``AutoConfig.from_pretrained`` and ``AutoModel.from_pretrained`` are replaced
        for the duration of the call (a process-wide patch) and restored afterwards.
        """
        configtmp = AutoConfig.from_pretrained
        modeltmp = AutoModel.from_pretrained
        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
        try:
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Restore even when loading/compilation raises; otherwise every later
            # AutoConfig/AutoModel.from_pretrained call would keep loading CLIP vision.
            AutoConfig.from_pretrained = configtmp
            AutoModel.from_pretrained = modeltmp

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
        """Wrap ``model`` in ``_VisionEncoder`` and put the wrapper in eval mode."""
        return _VisionEncoder(model).eval()

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "CLIPVisionConfig",
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """Build the RBLN compile configuration for ``CLIPVisionModel``.

        Args:
            preprocessors: Not used by this method.
            model_config: Vision config; ``image_size`` is read from it as a fallback.
            rbln_kwargs: User options. Keys read here: ``batch_size`` (default 1) and
                ``image_size`` (int for square images, or a ``(height, width)`` pair).
                Also forwarded to ``RBLNConfig``.

        Returns:
            An ``RBLNConfig`` with one float32 ``pixel_values`` input and
            ``model_cfg`` entries ``batch_size`` and ``image_size``.

        Raises:
            ValueError: if no image size is given and the config has none.
        """
        if rbln_kwargs is None:
            # Fresh dict per call; a ``{}`` default would be shared across calls and
            # stored inside every returned RBLNConfig.
            rbln_kwargs = {}

        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        if rbln_batch_size is None:
            # Also covers an explicit ``batch_size=None`` (matches RBLNCLIPTextModel).
            rbln_batch_size = 1

        rbln_image_size = rbln_kwargs.get("image_size", None)
        if rbln_image_size is None:
            rbln_image_size = getattr(model_config, "image_size", None)

        if isinstance(rbln_image_size, int):
            rbln_image_size = (rbln_image_size, rbln_image_size)

        if rbln_image_size is None:
            raise ValueError("`rbln_image_size` should be specified!")

        rbln_compile_config = RBLNCompileConfig(
            input_info=[
                (
                    "pixel_values",
                    [
                        rbln_batch_size,
                        3,
                        rbln_image_size[0],
                        rbln_image_size[1],
                    ],
                    "float32",
                )
            ]
        )

        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )

        rbln_config.model_cfg.update(
            {
                "batch_size": rbln_batch_size,
                "image_size": rbln_image_size,
            }
        )

        return rbln_config

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Run the compiled encoder on ``pixel_values``.

        Extra keyword arguments are ignored; a warning is logged if any of them is set.
        Output index 0 becomes ``last_hidden_state``, index 1 ``pooler_output`` and
        the remaining entries ``hidden_states``.
        """
        # ``bool()`` on a multi-element tensor raises, so a tensor kwarg counts as "set"
        # without testing its truthiness; other values keep the original truthy check.
        if any(isinstance(value, torch.Tensor) or bool(value) for value in kwargs.values()):
            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")

        output = super().forward(pixel_values)
        return BaseModelOutputWithPooling(
            last_hidden_state=output[0],
            pooler_output=output[1],
            # HF output classes document hidden_states as a tuple; normalize list outputs.
            hidden_states=tuple(output[2:]),
        )
@@ -93,17 +93,29 @@ class DecoderOnlyWrapper(torch.nn.Module):
93
93
 
94
94
  def forward(
95
95
  self,
96
- input_ids,
96
+ input_ids_or_inputs_embeds,
97
97
  attention_mask,
98
98
  cache_position,
99
99
  batch_position,
100
+ query_idx,
100
101
  *past_key_values,
101
102
  ):
102
- if input_ids.shape[1] == 1:
103
+ if input_ids_or_inputs_embeds.shape[1] == 1:
103
104
  rbln_batch_position = None
104
105
  else:
105
106
  rbln_batch_position = batch_position
106
107
 
108
+ if input_ids_or_inputs_embeds.ndim == 2:
109
+ # input_ids
110
+ input_ids = input_ids_or_inputs_embeds
111
+ inputs_embeds = None
112
+ elif input_ids_or_inputs_embeds.ndim == 3:
113
+ # inputs_embeds
114
+ input_ids = None
115
+ inputs_embeds = input_ids_or_inputs_embeds
116
+ else:
117
+ raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
118
+
107
119
  # Formatting list of past_kv to DynamicCache class.
108
120
  past_key_values = RebelDynamicCache.from_input_format(
109
121
  cache_position,
@@ -115,6 +127,7 @@ class DecoderOnlyWrapper(torch.nn.Module):
115
127
  outputs = forward_dict["wrapper"](
116
128
  self.model,
117
129
  input_ids=input_ids,
130
+ inputs_embeds=inputs_embeds,
118
131
  attention_mask=attention_mask,
119
132
  position_ids=cache_position,
120
133
  past_key_values=past_key_values,
@@ -124,11 +137,14 @@ class DecoderOnlyWrapper(torch.nn.Module):
124
137
  )
125
138
 
126
139
  hidden_states = outputs[0]
140
+ if batch_position >= 0:
141
+ hidden_states = hidden_states[:, query_idx].unsqueeze(1)
142
+
127
143
  logits = self.lm_head(hidden_states)
128
144
 
129
145
  output = (logits,) + outputs[1:]
130
146
 
131
- return output, batch_position
147
+ return output, batch_position + query_idx
132
148
 
133
149
 
134
150
  class DecoderOnlyAttention:
@@ -323,8 +339,16 @@ class DecoderOnlyModel:
323
339
  forward_dict: Optional[Dict[str, classmethod]] = None,
324
340
  rotary_pos_emb=None,
325
341
  ) -> BaseModelOutputWithPast:
342
+ # retrieve input_ids and inputs_embeds
343
+ if (input_ids is None) ^ (inputs_embeds is not None):
344
+ raise ValueError(
345
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
346
+ )
347
+
326
348
  # embed positions
327
- inputs_embeds = self.embed_tokens(input_ids)
349
+ if inputs_embeds is None:
350
+ inputs_embeds = self.embed_tokens(input_ids)
351
+
328
352
  hidden_states = inputs_embeds
329
353
  attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
330
354