optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
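
The headline change is the relocation of the seq2seq modeling code under `optimum/rbln/transformers/models/seq2seq/`, with per-model entry points split into `modeling_bart.py` and `modeling_t5.py`. For orientation, a caller-side sketch: the class name and the `rbln_*` keyword names come from the hunks below, while the `export=True` convention is an assumption carried over from optimum's usual `from_pretrained` API.

```python
# Usage sketch, not taken from the package docs. RBLNT5ForConditionalGeneration
# and the rbln_* keyword names appear in the diff below; export=True is assumed
# from optimum's standard export convention.
from optimum.rbln import RBLNT5ForConditionalGeneration

model = RBLNT5ForConditionalGeneration.from_pretrained(
    "t5-small",
    export=True,                # compile the HF checkpoint for the NPU
    rbln_batch_size=1,          # collected into rbln_kwargs["batch_size"]
    rbln_enc_max_seq_len=512,   # -> rbln_config.model_cfg["enc_max_seq_len"]
    rbln_dec_max_seq_len=256,   # -> rbln_config.model_cfg["dec_max_seq_len"]
)
```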
optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py}:

```diff
@@ -23,24 +23,17 @@
 
 import inspect
 import logging
-from
-
-
-import
-
-
-    BartConfig,
-    BartForConditionalGeneration,
-    PretrainedConfig,
-    T5ForConditionalGeneration,
-)
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import rebel  # noqa: F401
+import torch  # noqa: F401
+from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from
-from
-from .
-from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
-from .utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling_base import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -59,7 +52,6 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
         _ = super().forward(*args, **kwargs)
-        # Just indicates that it is not None
         return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
 
 
```
```diff
@@ -71,7 +63,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNModelForSeq2SeqLM(RBLNModel):
+class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
```
```diff
@@ -84,136 +76,59 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
     """
 
+    main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
 
     def __post_init__(self, **kwargs):
-        self.model_dim = self.config.d_model
-        self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
-        self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
-        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
-        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
 
-
-
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-
-
+        wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-
-
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+        wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-
-
-        else:
-            val = getattr(BartForConditionalGeneration, __name)
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
-
-
-        return val
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
-
-
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        decoder_batch_size = input_ids.shape[0]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+        enc_example_inputs[3].fill_(0)
+        dec_example_inputs[4].fill_(-1)
 
-
-
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
-    @classmethod
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-        def optimized_models(model):
-            if isinstance(model, T5ForConditionalGeneration):
-                encoder_model = T5EncoderWrapper(model).eval()
-                decoder_model = T5DecoderWrapper(model).eval()
-            elif isinstance(model, BartForConditionalGeneration):
-                encoder_model = BartEncoderWrapper(model).eval()
-                decoder_model = BartDecoderWrapper(model).eval()
-            else:
-                raise ValueError(f"{model.__class__.__name__} is not supported yet.")
-
-            return encoder_model, decoder_model
-
-        wrapped_encoder, wrapped_decoder = optimized_models(model)
-
-        wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-        wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-        wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
-
-        wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
-        wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
-        wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
-
-        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-        if isinstance(model, T5ForConditionalGeneration):
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-        else:
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
-        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+        enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
 
         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
 
         connections = [
-            (enc_ir.outputs[0], dec_ir.inputs[
-            (dec_ir.outputs[1], dec_ir.inputs[
+            (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
+            (dec_ir.outputs[1], dec_ir.inputs[5]),
         ]
+
         compiled_model = rebel.compile(
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model
 
```
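
Both the old and new paths compile via `torch.jit.trace` over fixed-shape dummy inputs; the rewrite just replaces the per-model `optimized_models` branching with a single `wrap_model_if_needed` hook and sources shapes from `RBLNCompileConfig`. A minimal, self-contained illustration of the tracing step (toy module, not the package's encoder wrapper):

```python
import torch

class ToyEncoder(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        # stand-in computation; the real wrappers run the HF encoder and cache writes
        return (input_ids.float() * attention_mask).sum(dim=-1, keepdim=True)

# fixed-shape dummy inputs, analogous to get_dummy_inputs(fill=0) above
example_inputs = (
    torch.zeros(1, 512, dtype=torch.int64),    # input_ids
    torch.zeros(1, 512, dtype=torch.float32),  # attention_mask
)
scripted = torch.jit.trace(ToyEncoder().eval(), example_inputs, check_trace=False)
print(type(scripted))  # traced module, frozen to the example shapes
```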
```diff
@@ -222,20 +137,20 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_dec_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = 1,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-
-
-
-
-
-
-
-        d_kv = model_config.d_kv
+        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        d_kv = (
+            model_config.d_kv
+            if hasattr(model_config, "d_kv")
+            else model_config.d_model // model_config.encoder_attention_heads
+        )
 
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
```
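
The new `_get_rbln_config` derives the KV-cache geometry from whichever attribute names the config actually carries: BART-style configs expose `decoder_layers`/`decoder_attention_heads` and no `d_kv`, while T5-style configs expose `num_layers`/`num_heads`/`d_kv`. The fallback chain can be checked directly against stock transformers configs:

```python
from transformers import BartConfig, T5Config

def kv_geometry(cfg):
    # same fallback chain as _get_rbln_config above
    n_layer = getattr(cfg, "decoder_layers", None) or getattr(cfg, "num_layers")
    n_head = getattr(cfg, "decoder_attention_heads", None) or getattr(cfg, "num_heads")
    d_kv = cfg.d_kv if hasattr(cfg, "d_kv") else cfg.d_model // cfg.encoder_attention_heads
    return n_layer, n_head, d_kv

print(kv_geometry(BartConfig()))  # (12, 16, 64): decoder_* names, d_kv derived
print(kv_geometry(T5Config()))    # (6, 8, 64): num_* names, explicit d_kv
```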
```diff
@@ -274,28 +189,34 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
-        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_pad_token_id"] = rbln_pad_token_id
-
         # model input info
         enc_input_info = [
-            ("input_ids", [
-            ("attention_mask", [
+            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            (
+                "cross_key_value_states",
+                [
+                    n_layer * 2,
+                    rbln_batch_size,
+                    n_head,
+                    rbln_enc_max_seq_len,
+                    d_kv,
+                ],
+                "float32",
+            ),
+            ("batch_idx", [], "int32"),
         ]
 
         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
-                [],
+                [rbln_batch_size, 1],
                 "int32",
             ),
+            ("batch_position", [], "int32"),
         ]
         dec_input_info.extend(
             [
```
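
Each `(name, shape, dtype)` triple pins one static tensor the compiled graph will accept, and it is also what `get_dummy_inputs` materializes before tracing. A rough stand-in for that materialization (the helper name is hypothetical; only the triple format is taken from the diff):

```python
import torch

# triples in the (name, shape, dtype) format used by enc_input_info/dec_input_info
input_info = [
    ("input_ids", [1, 512], "int64"),
    ("attention_mask", [1, 512], "float32"),
    ("batch_idx", [], "int32"),  # [] declares a 0-dim scalar tensor
]

def make_dummy_inputs(info, fill=0):
    # hypothetical stand-in for RBLNCompileConfig.get_dummy_inputs(fill=...)
    return [torch.full(shape, fill, dtype=getattr(torch, dtype)) for _, shape, dtype in info]

for t in make_dummy_inputs(input_info):
    print(tuple(t.shape), t.dtype)
```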
```diff
@@ -327,12 +248,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 )
             ]
         )
-
-
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
 
-        rbln_config = RBLNConfig
-
-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "enc_max_seq_len": rbln_enc_max_seq_len,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "pad_token_id": rbln_pad_token_id,
+            }
         )
 
         return rbln_config
```
```diff
@@ -347,7 +278,52 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
+    def can_generate(self):
+        return True
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        **kwargs,
+    ):
+        cur_seq_len = input_ids.shape[-1]
+        cache_position = cur_seq_len - 1
+        max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        decoder_batch_size = input_ids.shape[0]
+        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
+        decoder_attention_mask[:, :cur_seq_len] = 1
+
+        return {
+            "decoder_input_ids": input_ids,
+            "attention_mask": attention_mask.to(torch.float32),
+            "decoder_attention_mask": decoder_attention_mask,
+            "cache_position": cache_position,
+        }
+
     def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # common decoder
+        cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+        logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
+
+        return Seq2SeqLMOutput(
+            logits=logits,
+        )
+
+    def _forward_decoder(
         self,
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
```
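
`prepare_inputs_for_generation` now owns the mask bookkeeping that 0.1.9 kept in instance attributes: each decode step feeds exactly one token plus a full-width float mask whose shape never changes, so the compiled decoder always sees the same tensor layout. The slicing arithmetic in isolation:

```python
import torch

dec_max_seq_len = 8  # stand-in for rbln_config.model_cfg["dec_max_seq_len"]
input_ids = torch.tensor([[0, 11, 23], [0, 42, 7]])  # ids generated so far, batch of 2

cur_seq_len = input_ids.shape[-1]
step_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()  # (2, 1): newest token only
mask = torch.zeros(input_ids.shape[0], dec_max_seq_len, dtype=torch.float32)
mask[:, :cur_seq_len] = 1  # static width, growing prefix of ones

print(step_ids.tolist(), mask.tolist())
```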
```diff
@@ -355,35 +331,73 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
+        dec_attention_mask = decoder_attention_mask.clone()
+        for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
+            dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
+
         decoder_output = self.decoder(
             input_ids=decoder_input_ids,
-            attention_mask=
+            attention_mask=dec_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_position=torch.tensor(0, dtype=torch.int32),
         )
-        lm_logits = decoder_output.logits
+        lm_logits = decoder_output.logits[0]
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        batch_idx: Optional[torch.LongTensor] = None,
+        enc_lengths: List[int] = None,  # vllm return current attention_mask length
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
+        # encoder
+        if batch_idx is not None:
+            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
+            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
+            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+            logits[0][0][-1] = 1
+        # decoder
+        else:
+            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+            enc_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["enc_max_seq_len"],
+                dtype=torch.float32,
+            )
+            dec_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["dec_max_seq_len"],
+                dtype=torch.float32,
+            )
+            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
+                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+            logits = self._forward_decoder(
+                attention_mask=enc_attention_mask,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=dec_attention_mask,
+                cache_position=cache_position,
+            ).logits
+
+        return Seq2SeqLMOutput(logits=logits)
+
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
         model_input_name: Optional[str] = None,
-
-        **kwargs,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> Dict[str, Any]:
-        ########## thkim change start ###################
-        # padding input_ids & attention_mask regardless of user's tokenizer usage
-        batch_size, input_len = inputs_tensor.shape
-        inputs_tensor = torch.nn.functional.pad(
-            inputs_tensor, (0, self.enc_max_seq_len - input_len), value=self.pad_token_id
-        )
-        model_kwargs["attention_mask"] = torch.nn.functional.pad(
-            model_kwargs["attention_mask"], (0, self.enc_max_seq_len - input_len), value=0
-        )
-        ########## thkim change end ###################
-
         # 1. get encoder
         encoder = self.get_encoder()
 
```
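
The new `vllm_forward` dispatches on `batch_idx`: when it is set, only the encoder runs, and the returned dummy logits (a 1 at index `vocab_size + 99`) steer generation toward a sentinel id that the decoder branch later rewrites back to `decoder_start_token_id`. The rewrite itself is ordinary boolean-mask assignment; a toy version with T5-like numbers:

```python
import torch

vocab_size = 32128          # T5's default vocab size, used here only as an example
sentinel = vocab_size + 99  # placeholder id emitted after the encoder phase
decoder_start_token_id = 0  # T5's default decoder start id

input_ids = torch.tensor([[sentinel], [415]])
input_ids[input_ids == sentinel] = decoder_start_token_id
print(input_ids.tolist())  # [[0], [415]]
```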
```diff
@@ -401,10 +415,26 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
 
+        batch_size, input_len = inputs_tensor.shape
+        inputs_tensor = torch.nn.functional.pad(
+            inputs_tensor,
+            (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
+            value=self.rbln_config.model_cfg["pad_token_id"],
+        )
+        model_kwargs["attention_mask"] = torch.nn.functional.pad(
+            model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+        )
+
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
-        encoder_kwargs[
-
+        encoder_kwargs["output_hidden_states"] = False
+        encoder_kwargs["output_attentions"] = False
+
+        for b in range(batch_size):
+            batch_idx = torch.tensor(b, dtype=torch.int32)
+            encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
+            encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
 
         return model_kwargs
```
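
Because the compiled encoder has a fixed input width, `_prepare_encoder_decoder_kwargs_for_generation` right-pads every prompt to `enc_max_seq_len` and then runs the encoder once per sample with an explicit `batch_idx`, rather than as one batched call. The padding arithmetic on its own:

```python
import torch
import torch.nn.functional as F

enc_max_seq_len = 16  # stand-in for rbln_config.model_cfg["enc_max_seq_len"]
pad_token_id = 0      # stand-in for rbln_config.model_cfg["pad_token_id"]

input_ids = torch.tensor([[37, 423, 5]])
attention_mask = torch.ones_like(input_ids)

pad = enc_max_seq_len - input_ids.shape[-1]
input_ids = F.pad(input_ids, (0, pad), value=pad_token_id)          # (1, 16)
attention_mask = F.pad(attention_mask, (0, pad)).to(torch.float32)  # ones then zeros
print(input_ids.shape, attention_mask.tolist())
```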
optimum/rbln/transformers/models/t5/modeling_t5.py (new file):

```diff
@@ -0,0 +1,55 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import T5ForConditionalGeneration
+
+from ....modeling_config import RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .t5_architecture import T5Wrapper
+
+
+logger = get_logger()
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5Wrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(T5ForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
```