optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. optimum/rbln/__init__.py +7 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +172 -100
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +2 -0
  15. optimum/rbln/transformers/models/__init__.py +1 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
  24. optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
  25. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  26. optimum/rbln/utils/__init__.py +1 -1
  27. optimum/rbln/utils/import_utils.py +46 -0
  28. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
  29. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
  30. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  31. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_seq2seq.py
@@ -23,13 +23,10 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import (
     AutoModelForSeq2SeqLM,
     BartConfig,
@@ -39,12 +36,11 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from .modeling_base import RBLNBaseModel
+from .modeling_base import RBLNModel
 from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
 from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
 from .utils.runtime_utils import RBLNPytorchRuntime
-from .utils.save_utils import maybe_save_preprocessors
 
 
 logger = logging.getLogger(__name__)
@@ -75,7 +71,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNModelForSeq2SeqLM(RBLNBaseModel):
+class RBLNModelForSeq2SeqLM(RBLNModel):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -88,7 +84,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
     """
 
-    model_type = "rbln_model"
     auto_model_class = AutoModelForSeq2SeqLM
 
     def __post_init__(self, **kwargs):
@@ -97,8 +92,8 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
         self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
         self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
         self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
-        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
 
     def can_generate(self):
         return True
@@ -149,74 +144,18 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
         }
 
     @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "AutoModelForSeq2SeqLM":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
                 "return_dict": False,
-                "use_cache": False,
+                "use_cache": True,
             }
         )
+        return kwargs
 
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: AutoModelForSeq2SeqLM = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
+    @classmethod
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
         def optimized_models(model):
             if isinstance(model, T5ForConditionalGeneration):
                 encoder_model = T5EncoderWrapper(model).eval()
@@ -229,67 +168,54 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
 
             return encoder_model, decoder_model
 
-        def compile():
-            wrapped_encoder, wrapped_decoder = optimized_models(model)
+        wrapped_encoder, wrapped_decoder = optimized_models(model)
 
-            wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-            wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-            wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
+        wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
+        wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
 
-            wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-            wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-            wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+        wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
+        wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
+        wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
 
-            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-            if isinstance(model, T5ForConditionalGeneration):
-                enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-                dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-            else:
-                enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
-                dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-            enc_ir = rebel.torchscript_to_ir(
-                enc_scripted_model,
-                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-                name=enc_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-                name=dec_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
-
-            connections = [
-                (enc_ir.outputs[0], dec_ir.inputs[5]),
-                (dec_ir.outputs[1], dec_ir.inputs[4]),
-            ]
-            compiled_model = rebel.compile(
-                enc_ir,
-                dec_ir,
-                connections=connections,
-                fusion=enc_rbln_runtime_config.fusion,
-                npu=enc_rbln_runtime_config.npu,
-                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile()
-
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
+        if isinstance(model, T5ForConditionalGeneration):
+            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+        else:
+            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        enc_ir = rebel.torchscript_to_ir(
+            enc_scripted_model,
+            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
+            name=enc_rbln_runtime_config.rbln_mod_name,
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+            name=dec_rbln_runtime_config.rbln_mod_name,
         )
+        dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+
+        connections = [
+            (enc_ir.outputs[0], dec_ir.inputs[5]),
+            (dec_ir.outputs[1], dec_ir.inputs[4]),
+        ]
+        compiled_model = rebel.compile(
+            enc_ir,
+            dec_ir,
+            connections=connections,
+            fusion=enc_rbln_runtime_config.fusion,
+            npu=enc_rbln_runtime_config.npu,
+            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+        )
+        return compiled_model
 
     @classmethod
     def _get_rbln_config(
@@ -411,11 +337,14 @@
 
         return rbln_config
 
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
     def forward(
@@ -436,9 +365,6 @@
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
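With this change the seq2seq class no longer owns the whole export pipeline: checkpoint loading, save-directory handling, and the final `_from_pretrained` round trip now live in the shared `RBLNModel` base class, and this file only keeps the model-specific hooks (`update_kwargs`, `get_compiled_model`, and the now-classmethod `_create_runtimes`). A minimal sketch of how those hooks compose, assuming they are driven roughly the way the removed `_export` body did; the checkpoint name, the tokenizer-in-a-list call shape, and the import path are illustrative assumptions, not part of this diff:

# Hedged sketch: driving the 0.1.7 hooks by hand. The normal entry point is
# RBLNModelForSeq2SeqLM.from_pretrained(), whose wiring lives in modeling_base.py.
from transformers import AutoTokenizer, BartForConditionalGeneration
from optimum.rbln.modeling_seq2seq import RBLNModelForSeq2SeqLM  # import path assumed

kwargs = RBLNModelForSeq2SeqLM.update_kwargs({})  # adds torchscript=True, return_dict=False, use_cache=True
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", **kwargs)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

# Build the RBLNConfig (encoder/decoder runtime configs plus meta such as rbln_enc_max_seq_len).
rbln_config = RBLNModelForSeq2SeqLM.get_rbln_config(preprocessors=[tokenizer], model_config=model.config)

# New in 0.1.7: compilation returns the rebel compiled model to the caller
# instead of saving it into a temporary directory inside _export().
compiled_model = RBLNModelForSeq2SeqLM.get_compiled_model(model, rbln_config)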

optimum/rbln/transformers/__init__.py
@@ -31,6 +31,7 @@ _import_structure = {
     "models": [
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNDPTForDepthEstimation",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
@@ -44,6 +45,7 @@ if TYPE_CHECKING:
     from .models import (
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNDPTForDepthEstimation,
        RBLNGPT2LMHeadModel,
        RBLNLlamaForCausalLM,
        RBLNMidmLMHeadModel,
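Exposing the new class through `_import_structure` plus a mirrored import under `TYPE_CHECKING` is the usual lazy-module pattern: type checkers resolve the real symbol, while at runtime the DPT submodule is only imported when the attribute is first accessed. A small sketch of that pattern under the assumption that the package wires it up with transformers' `_LazyModule` (the wiring itself is outside this hunk):

# Illustrative only; the actual __init__.py glue in optimum/rbln/transformers is not shown in this diff.
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {"models": ["RBLNDPTForDepthEstimation"]}

if TYPE_CHECKING:
    from .models import RBLNDPTForDepthEstimation  # seen by static type checkers
else:
    # Deferred import: the dpt module is loaded on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)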

optimum/rbln/transformers/models/__init__.py
@@ -22,6 +22,7 @@
 # from Rebellions Inc.
 
 from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
+from .dpt import RBLNDPTForDepthEstimation
 from .gpt2 import RBLNGPT2LMHeadModel
 from .llama import RBLNLlamaForCausalLM
 from .midm import RBLNMidmLMHeadModel

optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -51,7 +51,6 @@ class _TextEncoder(torch.nn.Module):
 
 
 class RBLNCLIPTextModel(RBLNModel):
-    model_type = "rbln_clip"
     auto_model_class = AutoModel  # feature extraction
     original_model_class = CLIPTextModel
     original_config_class = CLIPTextConfig

optimum/rbln/transformers/models/dpt/__init__.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_dpt import RBLNDPTForDepthEstimation

optimum/rbln/transformers/models/dpt/modeling_dpt.py (new file)
@@ -0,0 +1,89 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import logging
+from typing import TYPE_CHECKING, Iterable, Optional, Union
+
+from transformers import AutoModelForDepthEstimation
+from transformers.modeling_outputs import DepthEstimatorOutput
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+class RBLNDPTForDepthEstimation(RBLNModel):
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForDepthEstimation
+    main_input_name = "pixel_values"
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_image_size: Optional[int] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        if rbln_image_size is None:
+            for processor in preprocessors:
+                image_size = getattr(processor, "size", None)
+
+                if image_size is not None:
+                    if isinstance(image_size, Iterable):
+                        if "shortest_edge" in image_size:
+                            rbln_image_size = image_size["shortest_edge"]
+                            break
+                        elif "height" in image_size and "width" in image_size:
+                            rbln_image_size = image_size["height"], image_size["width"]
+                            break
+                    else:
+                        rbln_image_size = image_size
+
+            if rbln_image_size is None:
+                rbln_image_size = getattr(model_config, "image_size", None)
+
+            if rbln_image_size is None:
+                raise ValueError("`rbln_image_size` should be specified!")
+
+        if isinstance(rbln_image_size, int):
+            rbln_image_size = rbln_image_size, rbln_image_size
+
+        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size[0], rbln_image_size[1]], "float32")]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        meta = {"rbln_image_size": rbln_image_size}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+    def forward(self, *args, **kwargs):
+        predicted_depth = super().forward(*args, **kwargs)
+        return DepthEstimatorOutput(predicted_depth=predicted_depth)
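RBLNDPTForDepthEstimation compiles the DPT backbone for a fixed `pixel_values` shape, resolved from the preprocessor's `size` (or from `rbln_image_size` / the model config), and its `forward` wraps the raw runtime output in `DepthEstimatorOutput`. A hedged usage sketch; the checkpoint name and the `export=True` / `rbln_*` keyword conventions follow the other RBLN model classes and are assumptions here, since the export entry point is not part of this file:

import torch
from optimum.rbln import RBLNDPTForDepthEstimation  # exported via the package __init__ in 0.1.7

# export=True and rbln_image_size are assumed keyword conventions, not shown in this diff.
model = RBLNDPTForDepthEstimation.from_pretrained("Intel/dpt-large", export=True, rbln_image_size=384)

# Input must match the compiled input_info: [rbln_batch_size, 3, H, W], float32.
pixel_values = torch.randn(1, 3, 384, 384)
outputs = model(pixel_values)
depth_map = outputs.predicted_depth  # DepthEstimatorOutput, as returned by the forward() override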

optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -21,7 +21,7 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -79,7 +79,7 @@ class _GPT2Attention(GPT2Attention):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        past_key_values: List[List[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -94,16 +94,15 @@ class _GPT2Attention(GPT2Attention):
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
-        if layer_past is not None:
-            past_key, past_value = layer_past
+        if past_key_values is not None:
+            past_key, past_value = past_key_values[self.layer_idx]
             query_length = query.shape[-2]
 
-            key = torch.slice_scatter(past_key, key, dim=2, start=cache_position, end=cache_position + query_length)
-            value = torch.slice_scatter(
-                past_value, value, dim=2, start=cache_position, end=cache_position + query_length
-            )
+            key = past_key.slice_scatter(key, dim=2, start=cache_position, end=cache_position + query_length)
+            value = past_value.slice_scatter(value, dim=2, start=cache_position, end=cache_position + query_length)
+
+            past_key_values[self.layer_idx] = [key, value]
 
-        present = (key, value)
         attn_output, _ = _GPT2Attention._attn(self, query, key, value, attention_mask, head_mask)
 
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
@@ -111,16 +110,14 @@ class _GPT2Attention(GPT2Attention):
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
 
-        outputs = (attn_output, present)
-
-        return outputs
+        return attn_output
 
 
 class _GPT2Block(GPT2Block):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        past_key_values: List[List[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -132,17 +129,15 @@ class _GPT2Block(GPT2Block):
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
 
-        attn_outputs = _GPT2Attention.forward(
+        attn_output = _GPT2Attention.forward(
             self.attn,
             hidden_states,
-            layer_past=layer_past,
+            past_key_values=past_key_values,
             attention_mask=attention_mask,
             head_mask=head_mask,
             cache_position=cache_position,
         )
 
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
         # residual connection
         hidden_states = attn_output + residual
 
@@ -152,15 +147,14 @@ class _GPT2Block(GPT2Block):
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        outputs = (hidden_states,) + outputs
-        return outputs  # hidden_states, present, (attentions, cross_attentions)
+        return hidden_states
 
 
 class _GPT2Model(GPT2Model):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: List[List[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -191,23 +185,19 @@ class _GPT2Model(GPT2Model):
 
         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
 
-        presents = ()
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            outputs = _GPT2Block.forward(
+        for i, block in enumerate(self.h):
+            hidden_states = _GPT2Block.forward(
                 block,
                 hidden_states,
-                layer_past=layer_past,
+                past_key_values=past_key_values,
                 attention_mask=attention_mask,
                 head_mask=head_mask[i],
                 cache_position=cache_position,
             )
-            hidden_states = outputs[0]
-
-            presents = presents + (outputs[1],)
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
-        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=presents)
+        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values)
 
 
 class GPT2LMHeadModelWrapper(torch.nn.Module):
@@ -218,13 +208,13 @@ class GPT2LMHeadModelWrapper(torch.nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        past_key_values: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.LongTensor,
+        *past_key_values: torch.Tensor,
    ):
         kv_cache = []
         for i in range(self.model.config.n_layer):
-            kv_cache.append((past_key_values[i, 0], past_key_values[i, 1]))
+            kv_cache.append((past_key_values[2 * i], past_key_values[2 * i + 1]))
 
         transformer_outputs = _GPT2Model.forward(
             self.model.transformer,
@@ -247,7 +237,8 @@
 
         past_key_values = []
         for i in range(self.model.config.n_layer):
-            past_key_values.append(torch.stack(kv_cache[i], dim=0))
-        past_key_values = torch.stack(past_key_values, dim=0)
+            past_key_values.append(kv_cache[i][0])
+            past_key_values.append(kv_cache[i][1])
 
-        return lm_logits, past_key_values
+        output = (lm_logits,) + tuple(past_key_values)
+        return output
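The wrapper's calling convention changes from a single stacked KV tensor (roughly shape `(n_layer, 2, ...)` in 0.1.4) to a flat variadic list: `forward` now takes `2 * n_layer` cache tensors after `cache_position` and returns the logits followed by the updated caches in the same flattened order, which keeps every traced input and output a plain tensor. A small sketch of the layout, with the tensor shapes chosen only for illustration:

import torch

n_layer, batch, n_head, max_seq, head_dim = 2, 1, 12, 128, 64

# One (key, value) pair per layer, flattened to [k0, v0, k1, v1, ...].
flat_kv = [torch.zeros(batch, n_head, max_seq, head_dim) for _ in range(2 * n_layer)]

# 0.1.7 call shape (matching GPT2LMHeadModelWrapper.forward):
#   wrapper(input_ids, attention_mask, cache_position, *flat_kv)
# and the output tuple is (lm_logits, k0, v0, k1, v1, ...).
key_layer_1 = flat_kv[2 * 1]        # index 2*i   -> layer i's key cache
value_layer_1 = flat_kv[2 * 1 + 1]  # index 2*i+1 -> layer i's value cache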