optimum-rbln 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- optimum/rbln/__init__.py +10 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +0 -2
- optimum/rbln/diffusers/models/controlnet.py +0 -6
- optimum/rbln/diffusers/models/unet_2d_condition.py +0 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +18 -20
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -20
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +19 -34
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +20 -35
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +12 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +13 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +13 -14
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +105 -139
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/models/__init__.py +4 -1
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +172 -100
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +148 -152
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +37 -12
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/utils/import_utils.py +14 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +4 -3
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/RECORD +54 -44
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
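The largest change below moves the seq2seq implementation under `transformers/models/seq2seq/` and adds a dedicated `RBLNT5ForConditionalGeneration`. A minimal usage sketch of that class, assuming the usual optimum-style `from_pretrained` flow; the checkpoint name, the top-level import, and the `rbln_*` kwargs are illustrative assumptions rather than values confirmed by this diff (the diff itself only shows a `batch_size` entry read from `rbln_kwargs`):

```python
# Hypothetical usage sketch; "t5-small" and the rbln_* kwargs are assumptions.
from optimum.rbln import RBLNT5ForConditionalGeneration

model = RBLNT5ForConditionalGeneration.from_pretrained(
    "t5-small",          # any HF T5 seq2seq checkpoint (illustrative)
    export=True,         # compile the PyTorch checkpoint for the RBLN NPU
    rbln_batch_size=1,   # surfaces as rbln_kwargs["batch_size"] in the diff
)
```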
optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} (truncated fragments in removed lines are left as extracted):

```diff
@@ -23,24 +23,17 @@
 
 import inspect
 import logging
-from
-
-
-import
-
-
-    BartConfig,
-    BartForConditionalGeneration,
-    PretrainedConfig,
-    T5ForConditionalGeneration,
-)
+from abc import ABC
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import rebel  # noqa: F401
+import torch  # noqa: F401
+from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from
-from
-from .
-from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
-from .utils.runtime_utils import RBLNPytorchRuntime
+from ....modeling_base import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
 logger = logging.getLogger(__name__)
@@ -59,7 +52,6 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
        _ = super().forward(*args, **kwargs)
-        # Just indicates that it is not None
        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
 
 
@@ -71,7 +63,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNModelForSeq2SeqLM(RBLNModel):
+class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -84,91 +76,35 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
     """
 
+    main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
 
     def __post_init__(self, **kwargs):
-        self.model_dim = self.config.d_model
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.enc_max_seq_len = self.rbln_config.model_cfg["enc_max_seq_len"]
-        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
-        self.pad_token_id = self.rbln_config.model_cfg["pad_token_id"]
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
-        self.enc_attention_mask = torch.zeros(1, self.enc_max_seq_len, dtype=torch.float32)
-        self.dec_enc_attention_mask = torch.zeros(self.batch_size, self.enc_max_seq_len, dtype=torch.float32)
-
-    def can_generate(self):
-        return True
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
-
-    def __getattr__(self, __name: str) -> Any:
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        if "T5ForConditionalGeneration" == self.config.architectures:
-            val = getattr(T5ForConditionalGeneration, __name)
-        else:
-            val = getattr(BartForConditionalGeneration, __name)
-
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
 
     @classmethod
-
-
-
-            encoder_model = T5EncoderWrapper(model).eval()
-            decoder_model = T5DecoderWrapper(model).eval()
-        elif isinstance(model, BartForConditionalGeneration):
-            encoder_model = BartEncoderWrapper(model).eval()
-            decoder_model = BartDecoderWrapper(model).eval()
-        else:
-            raise ValueError(f"{model.__class__.__name__} is not supported yet.")
-
-        return encoder_model, decoder_model
-
-        wrapped_encoder, wrapped_decoder = optimized_models(model)
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-
-
-        wrapped_encoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
+        wrapped_model.encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
-
-
-        wrapped_decoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]
+        wrapped_model.decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
 
         enc_rbln_compile_config = rbln_config.compile_cfgs[0]
         dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
-
-
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
-        else:
-            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
 
         enc_example_inputs[3].fill_(0)
         dec_example_inputs[4].fill_(-1)
 
-        enc_scripted_model = torch.jit.trace(
-        dec_scripted_model = torch.jit.trace(
+        enc_scripted_model = torch.jit.trace(wrapped_model.encoder, enc_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_model.decoder, dec_example_inputs, check_trace=False)
 
         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
```
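`get_compiled_model` now traces the encoder and decoder halves separately with `check_trace=False`, which skips re-running the traced graph against eager execution for verification. A minimal sketch of that call shape on a stand-in module (`TinyEncoder` is invented for illustration):

```python
import torch

class TinyEncoder(torch.nn.Module):
    """Stand-in for the wrapped encoder; two positional tensor inputs."""

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        return input_ids.float() * attention_mask

example_inputs = (torch.zeros(1, 8, dtype=torch.int64), torch.ones(1, 8))
# check_trace=False: trust the trace instead of verifying outputs against eager.
scripted = torch.jit.trace(TinyEncoder(), example_inputs, check_trace=False)
```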
```diff
@@ -180,13 +116,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             input_names=[v[0] for v in dec_rbln_compile_config.input_info],
             name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.decoder_batch_size = rbln_config.model_cfg["batch_size"]
 
         connections = [
             (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
-            # (enc_ir.outputs[0], enc_ir.inputs[2]),
             (dec_ir.outputs[1], dec_ir.inputs[5]),
         ]
+
         compiled_model = rebel.compile(
             enc_ir,
             dec_ir,
@@ -209,14 +144,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-
-
-
-
-
-
-
-        d_kv = model_config.d_kv
+        n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+        n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+        d_kv = (
+            model_config.d_kv
+            if hasattr(model_config, "d_kv")
+            else model_config.d_model // model_config.encoder_attention_heads
+        )
 
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
```
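The new `d_kv` fallback covers configs such as BART's, which expose `d_model` and per-side attention-head counts but no explicit per-head dimension. A quick check of the arithmetic using `BartConfig`'s defaults (d_model=1024, encoder_attention_heads=16):

```python
from transformers import BartConfig

config = BartConfig()  # defaults: d_model=1024, encoder_attention_heads=16
# Mirrors the fallback in the diff: BartConfig has no d_kv attribute.
d_kv = config.d_kv if hasattr(config, "d_kv") else config.d_model // config.encoder_attention_heads
assert d_kv == 64  # 1024 // 16: the per-head key/value width
```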
```diff
@@ -270,7 +204,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 ],
                 "float32",
             ),
-            # int16 available?
             ("batch_idx", [], "int32"),
         ]
 
@@ -281,7 +214,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             (
                 "cache_position",
                 [rbln_batch_size, 1],
-                # [],
                 "int32",
             ),
             ("batch_position", [], "int32"),
@@ -346,35 +278,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
+    def can_generate(self):
+        return True
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        past_key_values=None,
         attention_mask=None,
         decoder_attention_mask=None,
         **kwargs,
     ):
-        past_cache_length = past_key_values
-        if past_cache_length == 0:
-            cache_pos = []
-            for i in range(input_ids.shape[0]):
-                cache_pos.append([0])
-            cache_position = torch.tensor(cache_pos, dtype=torch.int32)
-
-        max_seq_len = self.dec_max_seq_len
         cur_seq_len = input_ids.shape[-1]
+        cache_position = cur_seq_len - 1
+        max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
         decoder_batch_size = input_ids.shape[0]
         input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-        # In greedy decoding
         decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
         decoder_attention_mask[:, :cur_seq_len] = 1
-
-        for i in range(input_ids.shape[0]):
-            cache_pos.append([cur_seq_len - 1])
-        cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+
         return {
             "decoder_input_ids": input_ids,
-            "past_key_values": past_key_values,
             "attention_mask": attention_mask.to(torch.float32),
             "decoder_attention_mask": decoder_attention_mask,
             "cache_position": cache_position,
@@ -383,41 +312,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
     def forward(
         self,
         input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # common decoder
-
-
-            return output
-
-        # vllm & encoder
-        if batch_idx is not None:
-            enc_attention_mask = self.enc_attention_mask.clone()
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.enc_max_seq_len - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # vllm & decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = self.dec_enc_attention_mask.clone()
-            dec_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
-            for batch_idx in range(self.batch_size):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
+        cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
+        logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
 
         return Seq2SeqLMOutput(
             logits=logits,
```
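In the slimmed-down `forward`, the scalar step index produced by `prepare_inputs_for_generation` is broadcast to a per-batch column before reaching the decoder runtime. The `torch.full` call behaves like this (batch size 2 and step 5 are arbitrary example values):

```python
import torch

batch_size, cache_position = 2, 5  # example values
# One identical step index per batch row, as the decoder runtime expects.
cache_position = torch.full((batch_size, 1), cache_position, dtype=torch.int32)
print(cache_position)  # tensor([[5], [5]], dtype=torch.int32)
```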
```diff
@@ -446,25 +346,58 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        batch_idx: Optional[torch.LongTensor] = None,
+        enc_lengths: List[int] = None,  # vllm return current attention_mask length
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
+        # encoder
+        if batch_idx is not None:
+            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
+            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
+            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+            logits[0][0][-1] = 1
+        # decoder
+        else:
+            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+            enc_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["enc_max_seq_len"],
+                dtype=torch.float32,
+            )
+            dec_attention_mask = torch.zeros(
+                self.rbln_config.model_cfg["batch_size"],
+                self.rbln_config.model_cfg["dec_max_seq_len"],
+                dtype=torch.float32,
+            )
+            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
+                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+            logits = self._forward_decoder(
+                attention_mask=enc_attention_mask,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=dec_attention_mask,
+                cache_position=cache_position,
+            ).logits
+
+        return Seq2SeqLMOutput(logits=logits)
+
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
         model_input_name: Optional[str] = None,
-
-        **kwargs,
+        generation_config: Optional[GenerationConfig] = None,
     ) -> Dict[str, Any]:
-        ########## thkim change start ###################
-        # padding input_ids & attention_mask regardless of user's tokenizer usage
-        batch_size, input_len = inputs_tensor.shape
-        inputs_tensor = torch.nn.functional.pad(
-            inputs_tensor, (0, self.enc_max_seq_len - input_len), value=self.pad_token_id
-        )
-        model_kwargs["attention_mask"] = torch.nn.functional.pad(
-            model_kwargs["attention_mask"], (0, self.enc_max_seq_len - input_len), value=0
-        )
-        ########## thkim change end ###################
-
         # 1. get encoder
         encoder = self.get_encoder()
 
```
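The comment inside `vllm_forward` describes a sentinel-token handshake: the encoder step emits a dummy distribution over `vocab_size + 100` entries whose argmax is `vocab_size + 99`, and the decoder step maps that sentinel back to the real start token. Replaying just that mechanism with t5-small's numbers (vocab_size=32100, decoder_start_token_id=0, used here only as an example):

```python
import torch

vocab_size, decoder_start_token_id = 32100, 0  # t5-small values, illustrative

# Encoder step: dummy logits whose argmax is the sentinel id vocab_size + 99.
logits = torch.zeros(1, 1, vocab_size + 100)
logits[0][0][-1] = 1
assert logits.argmax(-1).item() == vocab_size + 99

# Decoder step: swap the sentinel back to the real decoder start token.
input_ids = torch.tensor([[vocab_size + 99]])
input_ids[input_ids == (vocab_size + 99)] = decoder_start_token_id
assert input_ids.item() == decoder_start_token_id
```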
```diff
@@ -482,18 +415,26 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
 
+        batch_size, input_len = inputs_tensor.shape
+        inputs_tensor = torch.nn.functional.pad(
+            inputs_tensor,
+            (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len),
+            value=self.rbln_config.model_cfg["pad_token_id"],
+        )
+        model_kwargs["attention_mask"] = torch.nn.functional.pad(
+            model_kwargs["attention_mask"], (0, self.rbln_config.model_cfg["enc_max_seq_len"] - input_len)
+        )
+
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
-        encoder_kwargs[
+        encoder_kwargs["output_hidden_states"] = False
+        encoder_kwargs["output_attentions"] = False
+
         for b in range(batch_size):
             batch_idx = torch.tensor(b, dtype=torch.int32)
-
-
-
-            cb_inputs["output_attentions"] = False
-            cb_inputs["input_ids"] = encoder_kwargs["input_ids"][b].unsqueeze(0)
-            cb_inputs["attention_mask"] = encoder_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**cb_inputs, batch_idx=batch_idx)
+            encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
+            encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_idx=batch_idx)
 
         return model_kwargs
```
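The padding moved into `_prepare_encoder_decoder_kwargs_for_generation` is plain right-padding on the last dimension via `F.pad`. For a length-3 prompt padded to an assumed `enc_max_seq_len` of 8 with pad token 0:

```python
import torch

input_ids = torch.tensor([[5, 6, 7]])
enc_max_seq_len, pad_token_id = 8, 0  # illustrative config values

# (0, n) pads n elements on the right of the last dimension.
padded = torch.nn.functional.pad(
    input_ids, (0, enc_max_seq_len - input_ids.shape[-1]), value=pad_token_id
)
print(padded)  # tensor([[5, 6, 7, 0, 0, 0, 0, 0]])
```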
optimum/rbln/transformers/models/t5/modeling_t5.py (new file):

```diff
@@ -0,0 +1,55 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import T5ForConditionalGeneration
+
+from ....modeling_config import RBLNConfig
+from ....utils.logging import get_logger
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .t5_architecture import T5Wrapper
+
+
+logger = get_logger()
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5Wrapper(model)
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(T5ForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
```
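The `__getattr__` in the new T5 class borrows unbound methods from `T5ForConditionalGeneration` and rebinds them to the RBLN instance, so generation utilities resolve without inheriting the full PyTorch model. A self-contained sketch of the same pattern (`Donor` and `Borrower` are invented names):

```python
import inspect
from typing import Any, Callable

class Donor:
    def greet(self):
        return f"hello from {type(self).__name__}"

class Borrower:
    def __getattr__(self, name: str) -> Any:
        val = getattr(Donor, name)  # look up the unbound attribute on the donor class
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            # Rebind: call the donor's function with *this* instance as self.
            return lambda *pargs, **kwargs: val(self, *pargs, **kwargs)
        return val

print(Borrower().greet())  # -> "hello from Borrower"
```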
optimum/rbln/transformers/models/t5/t5_architecture.py:

```diff
@@ -43,6 +43,12 @@ if TYPE_CHECKING:
     from transformers import T5ForConditionalGeneration
 
 
+class T5Wrapper:
+    def __init__(self, model):
+        self.encoder = T5EncoderWrapper(model)
+        self.decoder = T5DecoderWrapper(model)
+
+
 class T5Encoder(T5Stack):
     def forward(
         self,
@@ -122,19 +128,26 @@ class T5EncoderWrapper(torch.nn.Module):
         )
         self.encoder_max_length = None
         self.decoder_max_length = None
-        self.decoder_batch_size = 1
 
-    def forward(
-
-
-
-
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cross_key_value: torch.Tensor = None,
+        batch_idx: torch.Tensor = None,
+    ) -> torch.Tensor:
         decoder_max_length = self.decoder_max_length or self.default_max_length
         encoder_max_length = self.encoder_max_length or self.default_max_length
 
         attn_layer = self.encoder.block[0].layer[0].SelfAttention
         encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(
+        encoder_outputs = T5Encoder.forward(
+            self.encoder,
+            input_ids,
+            attention_mask,
+            encoder_position_bias,
+            batch_ids=torch.tensor(0, dtype=torch.int32),
+        )
 
         attn_layer = self.decoder.block[0].layer[0].SelfAttention
         decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
@@ -145,22 +158,14 @@ class T5EncoderWrapper(torch.nn.Module):
 
         dummy_past_key_value = []
         for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(
-
-            )
-
-                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_key = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
-            pkv_cross_attn_value = torch.zeros(
-                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
-            )
+            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
+            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
+            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
             layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
             dummy_past_key_value.append(layer_pkv)
 
-        decoder_attention_mask = torch.zeros(
+        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
         decoder_attention_mask[:, :1] = 1
 
         # Since first step of decoder has different graph to further step of it,
@@ -168,7 +173,7 @@ class T5EncoderWrapper(torch.nn.Module):
         # TODO(jongho): Separate first-step-decoder.
         decoder_outputs = T5Decoder.forward(
             self.decoder,
-            input_ids=torch.zeros(
+            input_ids=torch.zeros(1, 1, dtype=torch.int64),
             attention_mask=decoder_attention_mask,
             position_bias=decoder_position_bias,
             encoder_decoder_position_bias=encoder_decoder_position_bias,
@@ -187,7 +192,7 @@ class T5EncoderWrapper(torch.nn.Module):
             cross_kv_cache.append(past_key_values[i][3])
         cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
 
-        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx+1)
+        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)
 
         return cross_key_value
 
@@ -240,6 +245,7 @@ class T5DecoderWrapper(torch.nn.Module):
         attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
         _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
 
+        # position_bias need to compute with batch (for cb)
         batch_decoder_position_bias = []
         for i in range(input_ids.shape[0]):
             batch_position_bias = _decoder_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
@@ -259,7 +265,7 @@ class T5DecoderWrapper(torch.nn.Module):
             encoder_decoder_position_bias=encoder_decoder_position_bias,
             past_key_values=kv_cache,
             cache_position=cache_position,
-            batch_ids=rbln_batch_position
+            batch_ids=rbln_batch_position,
         )
 
         past_key_values = decoder_outputs.past_key_values
@@ -312,7 +318,7 @@ class _T5Attention(T5Attention):
             value_states = shape(self.v(hidden_states), batch_size)
         else:
             # cross-attn
-            if cache_position.dim() == 0
+            if cache_position.dim() == 0:
                 key_states = shape(self.k(key_value_states), key_value_states.shape[0])
                 value_states = shape(self.v(key_value_states), key_value_states.shape[0])
                 past_key_value = key_states, value_states
@@ -331,18 +337,24 @@ class _T5Attention(T5Attention):
             batch_value_states = value_states[b].unsqueeze(0)
 
             if is_self_attn and past_key_value is not None:
-                batch_key_states =
-
+                batch_key_states = (
+                    past_key_value[0][b]
+                    .unsqueeze(0)
+                    .slice_scatter(
+                        batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    )
                 )
-                batch_value_states =
-
+                batch_value_states = (
+                    past_key_value[1][b]
+                    .unsqueeze(0)
+                    .slice_scatter(
+                        batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
+                    )
                 )
 
             scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
             scores += position_bias[b]
-            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-                scores
-            )
+            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
             attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
             all_key_states.append(batch_key_states)
             all_value_states.append(batch_value_states)
```
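Both KV-cache updates above rely on `Tensor.slice_scatter`, the out-of-place equivalent of writing into `tensor[..., start:end, :]`, which keeps the operation traceable for compilation. A minimal demonstration of the call pattern (shapes are illustrative):

```python
import torch

cache = torch.zeros(1, 4, 8, 16)  # (1, heads, max_len, d_kv), illustrative shape
new_kv = torch.ones(1, 4, 1, 16)  # current step's key (or value) states
pos = 2                           # this step's cache_position

# Write new_kv into time slot [pos, pos + 1) without in-place mutation.
updated = cache.slice_scatter(new_kv, dim=-2, start=pos, end=pos + 1)
assert updated[0, 0, pos, 0] == 1 and updated[0, 0, pos + 1, 0] == 0
```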
```diff
@@ -371,7 +383,9 @@ class _T5Attention(T5Attention):
             scores
         )  # (batch_size, n_heads, seq_length, key_length)
 
-        attn_output = unshape(
+        attn_output = unshape(
+            torch.matmul(attn_weights, value_states), batch_size
+        )  # (batch_size, seq_length, dim)
 
         attn_output = self.o(attn_output)
         present_key_value = (key_states, value_states)
```