optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +7 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_base.py +172 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
- optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
- optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
- optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_seq2seq.py
CHANGED
@@ -23,13 +23,10 @@
|
|
23
23
|
|
24
24
|
import inspect
|
25
25
|
import logging
|
26
|
-
from pathlib import Path
|
27
|
-
from tempfile import TemporaryDirectory
|
28
26
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
29
27
|
|
30
28
|
import rebel
|
31
29
|
import torch
|
32
|
-
from optimum.exporters import TasksManager
|
33
30
|
from transformers import (
|
34
31
|
AutoModelForSeq2SeqLM,
|
35
32
|
BartConfig,
|
@@ -39,12 +36,11 @@ from transformers import (
|
|
39
36
|
)
|
40
37
|
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
41
38
|
|
42
|
-
from .modeling_base import
|
39
|
+
from .modeling_base import RBLNModel
|
43
40
|
from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
|
44
41
|
from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
|
45
42
|
from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
|
46
43
|
from .utils.runtime_utils import RBLNPytorchRuntime
|
47
|
-
from .utils.save_utils import maybe_save_preprocessors
|
48
44
|
|
49
45
|
|
50
46
|
logger = logging.getLogger(__name__)
|
@@ -75,7 +71,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
75
71
|
return Seq2SeqLMOutput(logits=outputs)
|
76
72
|
|
77
73
|
|
78
|
-
class RBLNModelForSeq2SeqLM(
|
74
|
+
class RBLNModelForSeq2SeqLM(RBLNModel):
|
79
75
|
"""
|
80
76
|
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
|
81
77
|
This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
@@ -88,7 +84,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
88
84
|
Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
|
89
85
|
"""
|
90
86
|
|
91
|
-
model_type = "rbln_model"
|
92
87
|
auto_model_class = AutoModelForSeq2SeqLM
|
93
88
|
|
94
89
|
def __post_init__(self, **kwargs):
|
@@ -97,8 +92,8 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
97
92
|
self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
|
98
93
|
self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
|
99
94
|
self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
|
100
|
-
self.encoder = RBLNRuntimeEncoder(runtime=self.
|
101
|
-
self.decoder = RBLNRuntimeDecoder(runtime=self.
|
95
|
+
self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
|
96
|
+
self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
|
102
97
|
|
103
98
|
def can_generate(self):
|
104
99
|
return True
|
@@ -149,74 +144,18 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
149
144
|
}
|
150
145
|
|
151
146
|
@classmethod
|
152
|
-
def
|
153
|
-
cls,
|
154
|
-
model_id: str,
|
155
|
-
config: "PretrainedConfig",
|
156
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
157
|
-
revision: Optional[str] = None,
|
158
|
-
force_download: bool = False,
|
159
|
-
cache_dir: Optional[str] = None,
|
160
|
-
subfolder: str = "",
|
161
|
-
local_files_only: bool = False,
|
162
|
-
trust_remote_code: bool = False,
|
163
|
-
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
164
|
-
**kwargs,
|
165
|
-
) -> "AutoModelForSeq2SeqLM":
|
166
|
-
"""
|
167
|
-
Exports a vanilla Transformers model into a rbln-compiled Module.
|
168
|
-
"""
|
169
|
-
task = kwargs.pop("task", None)
|
170
|
-
if task is None:
|
171
|
-
task = TasksManager.infer_task_from_model(cls.auto_model_class)
|
172
|
-
|
173
|
-
if model_save_dir is None:
|
174
|
-
save_dir = TemporaryDirectory()
|
175
|
-
save_dir_path = Path(save_dir.name)
|
176
|
-
else:
|
177
|
-
save_dir = model_save_dir
|
178
|
-
if isinstance(save_dir, TemporaryDirectory):
|
179
|
-
save_dir_path = Path(model_save_dir.name)
|
180
|
-
else:
|
181
|
-
save_dir_path = Path(model_save_dir)
|
182
|
-
save_dir_path.mkdir(exist_ok=True)
|
183
|
-
|
147
|
+
def update_kwargs(cls, kwargs):
|
184
148
|
kwargs.update(
|
185
149
|
{
|
186
150
|
"torchscript": True,
|
187
151
|
"return_dict": False,
|
188
|
-
"use_cache":
|
152
|
+
"use_cache": True,
|
189
153
|
}
|
190
154
|
)
|
155
|
+
return kwargs
|
191
156
|
|
192
|
-
|
193
|
-
|
194
|
-
model: AutoModelForSeq2SeqLM = TasksManager.get_model_from_task(
|
195
|
-
task=task,
|
196
|
-
model_name_or_path=model_id,
|
197
|
-
subfolder=subfolder,
|
198
|
-
revision=revision,
|
199
|
-
framework="pt",
|
200
|
-
cache_dir=cache_dir,
|
201
|
-
use_auth_token=use_auth_token,
|
202
|
-
local_files_only=local_files_only,
|
203
|
-
force_download=force_download,
|
204
|
-
trust_remote_code=trust_remote_code,
|
205
|
-
**kwargs,
|
206
|
-
)
|
207
|
-
|
208
|
-
if config is None:
|
209
|
-
config = model.config
|
210
|
-
|
211
|
-
config.save_pretrained(save_dir_path)
|
212
|
-
preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
|
213
|
-
|
214
|
-
# Get compilation arguments
|
215
|
-
if rbln_config_kwargs.get("rbln_config", None) is None:
|
216
|
-
rbln_config = cls.get_rbln_config(
|
217
|
-
preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
|
218
|
-
)
|
219
|
-
|
157
|
+
@classmethod
|
158
|
+
def get_compiled_model(cls, model, rbln_config: RBLNConfig):
|
220
159
|
def optimized_models(model):
|
221
160
|
if isinstance(model, T5ForConditionalGeneration):
|
222
161
|
encoder_model = T5EncoderWrapper(model).eval()
|
@@ -229,67 +168,54 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
229
168
|
|
230
169
|
return encoder_model, decoder_model
|
231
170
|
|
232
|
-
|
233
|
-
wrapped_encoder, wrapped_decoder = optimized_models(model)
|
171
|
+
wrapped_encoder, wrapped_decoder = optimized_models(model)
|
234
172
|
|
235
|
-
|
236
|
-
|
237
|
-
|
173
|
+
wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
|
174
|
+
wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
|
175
|
+
wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
238
176
|
|
239
|
-
|
240
|
-
|
241
|
-
|
177
|
+
wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
|
178
|
+
wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
|
179
|
+
wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
242
180
|
|
243
|
-
|
244
|
-
|
181
|
+
enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
|
182
|
+
dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
|
245
183
|
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
)
|
266
|
-
dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
267
|
-
|
268
|
-
connections = [
|
269
|
-
(enc_ir.outputs[0], dec_ir.inputs[5]),
|
270
|
-
(dec_ir.outputs[1], dec_ir.inputs[4]),
|
271
|
-
]
|
272
|
-
compiled_model = rebel.compile(
|
273
|
-
enc_ir,
|
274
|
-
dec_ir,
|
275
|
-
connections=connections,
|
276
|
-
fusion=enc_rbln_runtime_config.fusion,
|
277
|
-
npu=enc_rbln_runtime_config.npu,
|
278
|
-
tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
|
279
|
-
)
|
280
|
-
compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
|
281
|
-
|
282
|
-
compile()
|
283
|
-
|
284
|
-
rbln_config.save(save_dir_path)
|
285
|
-
|
286
|
-
return cls._from_pretrained(
|
287
|
-
model_id=save_dir_path,
|
288
|
-
config=config,
|
289
|
-
model_save_dir=save_dir,
|
290
|
-
**rbln_constructor_kwargs,
|
291
|
-
**kwargs,
|
184
|
+
if isinstance(model, T5ForConditionalGeneration):
|
185
|
+
enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
|
186
|
+
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
|
187
|
+
else:
|
188
|
+
enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
|
189
|
+
dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
|
190
|
+
|
191
|
+
enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
|
192
|
+
dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
|
193
|
+
|
194
|
+
enc_ir = rebel.torchscript_to_ir(
|
195
|
+
enc_scripted_model,
|
196
|
+
input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
|
197
|
+
name=enc_rbln_runtime_config.rbln_mod_name,
|
198
|
+
)
|
199
|
+
dec_ir = rebel.torchscript_to_ir(
|
200
|
+
dec_scripted_model,
|
201
|
+
input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
|
202
|
+
name=dec_rbln_runtime_config.rbln_mod_name,
|
292
203
|
)
|
204
|
+
dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
|
205
|
+
|
206
|
+
connections = [
|
207
|
+
(enc_ir.outputs[0], dec_ir.inputs[5]),
|
208
|
+
(dec_ir.outputs[1], dec_ir.inputs[4]),
|
209
|
+
]
|
210
|
+
compiled_model = rebel.compile(
|
211
|
+
enc_ir,
|
212
|
+
dec_ir,
|
213
|
+
connections=connections,
|
214
|
+
fusion=enc_rbln_runtime_config.fusion,
|
215
|
+
npu=enc_rbln_runtime_config.npu,
|
216
|
+
tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
|
217
|
+
)
|
218
|
+
return compiled_model
|
293
219
|
|
294
220
|
@classmethod
|
295
221
|
def _get_rbln_config(
|
@@ -411,11 +337,14 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
411
337
|
|
412
338
|
return rbln_config
|
413
339
|
|
414
|
-
|
340
|
+
@classmethod
|
341
|
+
def _create_runtimes(
|
342
|
+
cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
|
343
|
+
) -> List[rebel.Runtime]:
|
415
344
|
device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
|
416
345
|
return [
|
417
|
-
|
418
|
-
|
346
|
+
compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
|
347
|
+
compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
|
419
348
|
]
|
420
349
|
|
421
350
|
def forward(
|
@@ -436,9 +365,6 @@ class RBLNModelForSeq2SeqLM(RBLNBaseModel):
|
|
436
365
|
|
437
366
|
return Seq2SeqLMOutput(logits=lm_logits)
|
438
367
|
|
439
|
-
def __repr__(self):
|
440
|
-
return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
|
441
|
-
|
442
368
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
443
369
|
self,
|
444
370
|
inputs_tensor: torch.Tensor,
|
@@ -31,6 +31,7 @@ _import_structure = {
|
|
31
31
|
"models": [
|
32
32
|
"RBLNCLIPTextModel",
|
33
33
|
"RBLNCLIPTextModelWithProjection",
|
34
|
+
"RBLNDPTForDepthEstimation",
|
34
35
|
"RBLNGPT2LMHeadModel",
|
35
36
|
"RBLNWav2Vec2ForCTC",
|
36
37
|
"RBLNWhisperForConditionalGeneration",
|
@@ -44,6 +45,7 @@ if TYPE_CHECKING:
|
|
44
45
|
from .models import (
|
45
46
|
RBLNCLIPTextModel,
|
46
47
|
RBLNCLIPTextModelWithProjection,
|
48
|
+
RBLNDPTForDepthEstimation,
|
47
49
|
RBLNGPT2LMHeadModel,
|
48
50
|
RBLNLlamaForCausalLM,
|
49
51
|
RBLNMidmLMHeadModel,
|
@@ -22,6 +22,7 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
|
25
|
+
from .dpt import RBLNDPTForDepthEstimation
|
25
26
|
from .gpt2 import RBLNGPT2LMHeadModel
|
26
27
|
from .llama import RBLNLlamaForCausalLM
|
27
28
|
from .midm import RBLNMidmLMHeadModel
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .modeling_dpt import RBLNDPTForDepthEstimation
|
@@ -0,0 +1,89 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import logging
|
25
|
+
from typing import TYPE_CHECKING, Iterable, Optional, Union
|
26
|
+
|
27
|
+
from transformers import AutoModelForDepthEstimation
|
28
|
+
from transformers.modeling_outputs import DepthEstimatorOutput
|
29
|
+
|
30
|
+
from ....modeling_base import RBLNModel
|
31
|
+
from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
|
32
|
+
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
if TYPE_CHECKING:
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNDPTForDepthEstimation(RBLNModel):
|
41
|
+
model_type = "rbln_model"
|
42
|
+
auto_model_class = AutoModelForDepthEstimation
|
43
|
+
main_input_name = "pixel_values"
|
44
|
+
|
45
|
+
@classmethod
|
46
|
+
def _get_rbln_config(
|
47
|
+
cls,
|
48
|
+
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
49
|
+
model_config: Optional["PretrainedConfig"] = None,
|
50
|
+
rbln_image_size: Optional[int] = None,
|
51
|
+
rbln_batch_size: Optional[int] = None,
|
52
|
+
) -> RBLNConfig:
|
53
|
+
if rbln_batch_size is None:
|
54
|
+
rbln_batch_size = 1
|
55
|
+
|
56
|
+
if rbln_image_size is None:
|
57
|
+
for processor in preprocessors:
|
58
|
+
image_size = getattr(processor, "size", None)
|
59
|
+
|
60
|
+
if image_size is not None:
|
61
|
+
if isinstance(image_size, Iterable):
|
62
|
+
if "shortest_edge" in image_size:
|
63
|
+
rbln_image_size = image_size["shortest_edge"]
|
64
|
+
break
|
65
|
+
elif "height" in image_size and "width" in image_size:
|
66
|
+
rbln_image_size = image_size["height"], image_size["width"]
|
67
|
+
break
|
68
|
+
else:
|
69
|
+
rbln_image_size = image_size
|
70
|
+
|
71
|
+
if rbln_image_size is None:
|
72
|
+
rbln_image_size = getattr(model_config, "image_size", None)
|
73
|
+
|
74
|
+
if rbln_image_size is None:
|
75
|
+
raise ValueError("`rbln_image_size` should be specified!")
|
76
|
+
|
77
|
+
if isinstance(rbln_image_size, int):
|
78
|
+
rbln_image_size = rbln_image_size, rbln_image_size
|
79
|
+
|
80
|
+
input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size[0], rbln_image_size[1]], "float32")]
|
81
|
+
|
82
|
+
rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
|
83
|
+
meta = {"rbln_image_size": rbln_image_size}
|
84
|
+
|
85
|
+
return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
|
86
|
+
|
87
|
+
def forward(self, *args, **kwargs):
|
88
|
+
predicted_depth = super().forward(*args, **kwargs)
|
89
|
+
return DepthEstimatorOutput(predicted_depth=predicted_depth)
|
@@ -21,7 +21,7 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
from typing import Optional, Tuple, Union
|
24
|
+
from typing import List, Optional, Tuple, Union
|
25
25
|
|
26
26
|
import torch
|
27
27
|
import torch.nn as nn
|
@@ -79,7 +79,7 @@ class _GPT2Attention(GPT2Attention):
|
|
79
79
|
def forward(
|
80
80
|
self,
|
81
81
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
82
|
-
|
82
|
+
past_key_values: List[List[torch.Tensor]] = None,
|
83
83
|
attention_mask: Optional[torch.FloatTensor] = None,
|
84
84
|
head_mask: Optional[torch.FloatTensor] = None,
|
85
85
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
@@ -94,16 +94,15 @@ class _GPT2Attention(GPT2Attention):
|
|
94
94
|
key = self._split_heads(key, self.num_heads, self.head_dim)
|
95
95
|
value = self._split_heads(value, self.num_heads, self.head_dim)
|
96
96
|
|
97
|
-
if
|
98
|
-
past_key, past_value =
|
97
|
+
if past_key_values is not None:
|
98
|
+
past_key, past_value = past_key_values[self.layer_idx]
|
99
99
|
query_length = query.shape[-2]
|
100
100
|
|
101
|
-
key =
|
102
|
-
value =
|
103
|
-
|
104
|
-
|
101
|
+
key = past_key.slice_scatter(key, dim=2, start=cache_position, end=cache_position + query_length)
|
102
|
+
value = past_value.slice_scatter(value, dim=2, start=cache_position, end=cache_position + query_length)
|
103
|
+
|
104
|
+
past_key_values[self.layer_idx] = [key, value]
|
105
105
|
|
106
|
-
present = (key, value)
|
107
106
|
attn_output, _ = _GPT2Attention._attn(self, query, key, value, attention_mask, head_mask)
|
108
107
|
|
109
108
|
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
@@ -111,16 +110,14 @@ class _GPT2Attention(GPT2Attention):
|
|
111
110
|
attn_output = self.c_proj(attn_output)
|
112
111
|
attn_output = self.resid_dropout(attn_output)
|
113
112
|
|
114
|
-
|
115
|
-
|
116
|
-
return outputs
|
113
|
+
return attn_output
|
117
114
|
|
118
115
|
|
119
116
|
class _GPT2Block(GPT2Block):
|
120
117
|
def forward(
|
121
118
|
self,
|
122
119
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
123
|
-
|
120
|
+
past_key_values: List[List[torch.Tensor]] = None,
|
124
121
|
attention_mask: Optional[torch.FloatTensor] = None,
|
125
122
|
head_mask: Optional[torch.FloatTensor] = None,
|
126
123
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
@@ -132,17 +129,15 @@ class _GPT2Block(GPT2Block):
|
|
132
129
|
residual = hidden_states
|
133
130
|
hidden_states = self.ln_1(hidden_states)
|
134
131
|
|
135
|
-
|
132
|
+
attn_output = _GPT2Attention.forward(
|
136
133
|
self.attn,
|
137
134
|
hidden_states,
|
138
|
-
|
135
|
+
past_key_values=past_key_values,
|
139
136
|
attention_mask=attention_mask,
|
140
137
|
head_mask=head_mask,
|
141
138
|
cache_position=cache_position,
|
142
139
|
)
|
143
140
|
|
144
|
-
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
145
|
-
outputs = attn_outputs[1:]
|
146
141
|
# residual connection
|
147
142
|
hidden_states = attn_output + residual
|
148
143
|
|
@@ -152,15 +147,14 @@ class _GPT2Block(GPT2Block):
|
|
152
147
|
# residual connection
|
153
148
|
hidden_states = residual + feed_forward_hidden_states
|
154
149
|
|
155
|
-
|
156
|
-
return outputs # hidden_states, present, (attentions, cross_attentions)
|
150
|
+
return hidden_states
|
157
151
|
|
158
152
|
|
159
153
|
class _GPT2Model(GPT2Model):
|
160
154
|
def forward(
|
161
155
|
self,
|
162
156
|
input_ids: Optional[torch.LongTensor] = None,
|
163
|
-
past_key_values:
|
157
|
+
past_key_values: List[List[torch.Tensor]] = None,
|
164
158
|
attention_mask: Optional[torch.FloatTensor] = None,
|
165
159
|
position_ids: Optional[torch.LongTensor] = None,
|
166
160
|
head_mask: Optional[torch.FloatTensor] = None,
|
@@ -191,23 +185,19 @@ class _GPT2Model(GPT2Model):
|
|
191
185
|
|
192
186
|
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
193
187
|
|
194
|
-
|
195
|
-
|
196
|
-
outputs = _GPT2Block.forward(
|
188
|
+
for i, block in enumerate(self.h):
|
189
|
+
hidden_states = _GPT2Block.forward(
|
197
190
|
block,
|
198
191
|
hidden_states,
|
199
|
-
|
192
|
+
past_key_values=past_key_values,
|
200
193
|
attention_mask=attention_mask,
|
201
194
|
head_mask=head_mask[i],
|
202
195
|
cache_position=cache_position,
|
203
196
|
)
|
204
|
-
hidden_states = outputs[0]
|
205
|
-
|
206
|
-
presents = presents + (outputs[1],)
|
207
197
|
|
208
198
|
hidden_states = self.ln_f(hidden_states)
|
209
199
|
hidden_states = hidden_states.view(output_shape)
|
210
|
-
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=
|
200
|
+
return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values)
|
211
201
|
|
212
202
|
|
213
203
|
class GPT2LMHeadModelWrapper(torch.nn.Module):
|
@@ -218,13 +208,13 @@ class GPT2LMHeadModelWrapper(torch.nn.Module):
|
|
218
208
|
def forward(
|
219
209
|
self,
|
220
210
|
input_ids: torch.Tensor,
|
221
|
-
past_key_values: torch.Tensor,
|
222
211
|
attention_mask: torch.Tensor,
|
223
212
|
cache_position: torch.LongTensor,
|
213
|
+
*past_key_values: torch.Tensor,
|
224
214
|
):
|
225
215
|
kv_cache = []
|
226
216
|
for i in range(self.model.config.n_layer):
|
227
|
-
kv_cache.append((past_key_values[i
|
217
|
+
kv_cache.append((past_key_values[2 * i], past_key_values[2 * i + 1]))
|
228
218
|
|
229
219
|
transformer_outputs = _GPT2Model.forward(
|
230
220
|
self.model.transformer,
|
@@ -247,7 +237,8 @@ class GPT2LMHeadModelWrapper(torch.nn.Module):
|
|
247
237
|
|
248
238
|
past_key_values = []
|
249
239
|
for i in range(self.model.config.n_layer):
|
250
|
-
past_key_values.append(
|
251
|
-
|
240
|
+
past_key_values.append(kv_cache[i][0])
|
241
|
+
past_key_values.append(kv_cache[i][1])
|
252
242
|
|
253
|
-
|
243
|
+
output = (lm_logits,) + tuple(past_key_values)
|
244
|
+
return output
|