optimum_rbln-0.7.2rc1-py3-none-any.whl → optimum_rbln-0.7.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +9 -4
- optimum/rbln/diffusers/__init__.py +8 -0
- optimum/rbln/diffusers/modeling_diffusers.py +103 -109
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
- optimum/rbln/diffusers/pipelines/__init__.py +8 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
- optimum/rbln/modeling.py +4 -1
- optimum/rbln/modeling_base.py +16 -3
- optimum/rbln/ops/__init__.py +6 -2
- optimum/rbln/ops/attn.py +94 -85
- optimum/rbln/ops/flash_attn.py +46 -25
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
- optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
- optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.2rc1.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/phi/phi_architecture.py

@@ -32,11 +32,13 @@ if TYPE_CHECKING:
 
 
 class PhiWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = PhiAttention(
+                new_self_attn = PhiAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
             else:
@@ -81,30 +83,30 @@ class PhiLayer(DecoderOnlyLayer):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-
+        attn_output = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )
 
         feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
 
-        hidden_states =
+        hidden_states = attn_output + feed_forward_hidden_states + residual
 
-        return hidden_states
+        return hidden_states
 
 
 class PhiModel(DecoderOnlyModel):
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -50,11 +50,18 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
+        support_paged_causal_attn: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        if support_paged_causal_attn:
+            self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+        else:
+            self.default_block_tables = None
 
     def forward(
         self,
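With paged causal attention enabled, each batch slot gets its own single-block entry in the default table when the caller does not supply one. A minimal sketch of the layout produced by the expression above (only `torch` assumed; the batch size is illustrative):

```python
import torch

batch_size = 4  # illustrative
# Same expression as in RBLNRuntimeDecoder.__init__ above:
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables)
# tensor([[0],
#         [1],
#         [2],
#         [3]], dtype=torch.int16)
```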
@@ -62,6 +69,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         batch_size = decoder_input_ids.shape[0]
@@ -73,19 +81,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-
-
-
-
-
-
-
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        if block_tables is None:
+            block_tables = self.default_block_tables
 
         lm_logits = super().forward(
             decoder_input_ids,
-            decoder_attention_mask,
+            decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
+            block_tables=block_tables,
         )
 
         return Seq2SeqLMOutput(logits=lm_logits)
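When the compiled decoder was built with an attention-mask input, the runtime regenerates that mask from `cache_position` on every call instead of carrying mask state across steps. A small stand-alone sketch of the per-batch update (values are illustrative; only `torch` assumed):

```python
import torch

dec_max_seq_len = 8                        # illustrative
cache_position = torch.tensor([[3], [5]])  # next decoding step per sequence
decoder_attention_mask = torch.zeros(2, dec_max_seq_len)

for b_idx in range(cache_position.shape[0]):
    decoding_step = cache_position[b_idx].item()
    decoder_attention_mask[b_idx, : decoding_step + 1] = 1

print(decoder_attention_mask)
# tensor([[1., 1., 1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 0., 0.]])
```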
@@ -106,16 +119,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
     main_input_name = "input_ids"
     auto_model_class = AutoModelForSeq2SeqLM
+    support_paged_causal_attn = None
 
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+
         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
             main_input_name="input_ids",
         )
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+            support_paged_causal_attn=self.support_paged_causal_attn,
+            use_attention_mask=self.use_attention_mask,
         )
 
     @classmethod
@@ -172,6 +193,16 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
+        if cls.support_paged_causal_attn:
+            rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+            if rbln_use_attention_mask is None:
+                rbln_use_attention_mask = False
+                rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+                if rbln_npu == "RBLN-CA02":
+                    rbln_use_attention_mask = True
+        else:
+            rbln_use_attention_mask = True
+
         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
         d_kv = (
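The hunk above picks the mask behaviour at compile time: models that support paged causal attention skip the mask unless the user requests it or the target NPU is reported as RBLN-CA02; every other model keeps the mask. A sketch of that rule as a plain helper (the function name is ours, not part of the package; the non-CA02 device string is a placeholder):

```python
from typing import Optional


def resolve_use_attention_mask(
    support_paged_causal_attn: bool,
    requested: Optional[bool],
    npu_name: str,
) -> bool:
    """Mirrors the selection rule in the hunk above (illustrative helper)."""
    if not support_paged_causal_attn:
        return True                 # non-paged decoders always take a mask
    if requested is not None:
        return requested            # an explicit user setting wins
    return npu_name == "RBLN-CA02"  # only this target still needs the mask


assert resolve_use_attention_mask(False, None, "SOME-OTHER-NPU") is True
assert resolve_use_attention_mask(True, None, "RBLN-CA02") is True
assert resolve_use_attention_mask(True, None, "SOME-OTHER-NPU") is False
```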
@@ -232,12 +263,11 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 ],
                 "float32",
             ),
-            ("
+            ("block_tables", [1], "int16"),
         ]
 
         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
@@ -275,6 +305,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 for i in range(n_layer * 2)
             ]
         )
+
+        if cls.support_paged_causal_attn:
+            dec_input_info.insert(3, ("block_tables", [rbln_batch_size, 1], "int16"))
+            if rbln_use_attention_mask:
+                dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
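Those two `insert` calls change the compiled decoder's positional input layout depending on the flags. A quick sketch of the resulting name order (shapes and the KV-cache entries are elided):

```python
def decoder_input_names(support_paged_causal_attn: bool, use_attention_mask: bool) -> list:
    names = ["input_ids", "encoder_attention_mask", "cache_position"]  # KV-cache entries follow these
    if support_paged_causal_attn:
        names.insert(3, "block_tables")
        if use_attention_mask:
            names.insert(1, "attention_mask")
    return names


print(decoder_input_names(True, True))
# ['input_ids', 'attention_mask', 'encoder_attention_mask', 'cache_position', 'block_tables']
print(decoder_input_names(True, False))
# ['input_ids', 'encoder_attention_mask', 'cache_position', 'block_tables']
```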
@@ -290,6 +326,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "pad_token_id": rbln_pad_token_id,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )
 
@@ -400,9 +437,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         encoder_kwargs["output_attentions"] = False
 
         for b in range(batch_size):
-
+            block_tables = torch.tensor([b], dtype=torch.int16)
             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs,
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
 
         return model_kwargs
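At the user level, the stored `use_attention_mask` flag comes from the usual `rbln_`-prefixed keyword arguments. A hedged sketch of pinning it explicitly while compiling a seq2seq checkpoint (checkpoint name and sizes are illustrative, and the exact set of supported kwargs should be checked against the release):

```python
from optimum.rbln import RBLNBartForConditionalGeneration

# Illustrative only: force the attention-mask input on a paged-attention decoder.
model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",
    export=True,                   # compile from the Hugging Face checkpoint
    rbln_batch_size=1,
    rbln_dec_max_seq_len=1024,
    rbln_use_attention_mask=True,  # ends up in rbln_config.model_cfg["use_attention_mask"]
)
```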
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py

@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import
+from ....ops import (
+    register_rbln_custom_cache_update,
+    register_rbln_custom_paged_attention,
+    register_rbln_custom_paged_causal_attention,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -87,7 +91,7 @@ class Seq2SeqEncoderWrapper(nn.Module):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         cross_key_values: torch.Tensor,
-
+        b_idx: torch.Tensor,
     ) -> Tuple[torch.Tensor]:
         # 1. get encoder last_hidden_states
         encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
@@ -110,11 +114,9 @@
 
         # 3. update the cross_attention's past_key_value direct to the device-dram for optimization.
         batch_axis = torch.tensor(1, dtype=torch.int16)
-
-            cross_key_values, cross_kv, batch_position, batch_axis
-        )
+        enc_out = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, b_idx[0], batch_axis)
 
-        return
+        return enc_out
 
 
 class Seq2SeqDecoderWrapper(nn.Module):
@@ -131,9 +133,10 @@ class Seq2SeqDecoderWrapper(nn.Module):
         **kwargs: Additional arguments for decoder configuration.
     """
 
-    def __init__(self, model: nn.Module, **kwargs):
+    def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
         super().__init__()
         self.config = model.config
+        self.use_attention_mask = use_attention_mask
         self.__post_init__(model, **kwargs)
 
     def __post_init__(self, model: nn.Module, **kwargs):
@@ -143,7 +146,11 @@ class Seq2SeqDecoderWrapper(nn.Module):
         It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
         by subclasses to modify or add custom attributes as necessary.
         """
-
+        if self.use_attention_mask:
+            register_rbln_custom_paged_attention()
+        else:
+            register_rbln_custom_paged_causal_attention()
+
         self.num_layers = self.config.decoder_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model)
 
@@ -160,13 +167,23 @@ class Seq2SeqDecoderWrapper(nn.Module):
 
     def forward(
         self,
-
-        attention_mask: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-        *self_kv_cache: torch.Tensor,
+        *args,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        if self.use_attention_mask:
+            (
+                input_ids,
+                attention_mask,
+                encoder_attention_mask,
+                cache_position,
+                block_tables,
+                cross_kv_cache,
+                *self_kv_cache,
+            ) = args
+
+        else:
+            attention_mask = None
+            (input_ids, encoder_attention_mask, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
+
         self_past_key_values = ()
         cross_past_key_values = ()
         for i in range(0, self.num_layers * 2, 2):
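Since the traced decoder receives one flat positional tuple, the wrapper recovers named inputs from `*args`, and the layout shifts by one slot depending on whether the attention mask was compiled in. A small sketch that mirrors the unpacking above (string stand-ins instead of tensors):

```python
def unpack_decoder_args(use_attention_mask: bool, *args):
    """Mirrors Seq2SeqDecoderWrapper.forward above (illustrative)."""
    if use_attention_mask:
        (input_ids, attention_mask, encoder_attention_mask,
         cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
    else:
        attention_mask = None
        (input_ids, encoder_attention_mask,
         cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args
    return input_ids, attention_mask, cache_position, block_tables


print(unpack_decoder_args(True, "ids", "mask", "enc_mask", "pos", "blocks", "cross_kv", "k0", "v0"))
# ('ids', 'mask', 'pos', 'blocks')
print(unpack_decoder_args(False, "ids", "enc_mask", "pos", "blocks", "cross_kv", "k0", "v0"))
# ('ids', None, 'pos', 'blocks')
```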
@@ -174,18 +191,17 @@
             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
 
         # decode
-        lm_logits
+        lm_logits = self.conditional_generation(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
 
-
-
-        return outputs
+        return lm_logits
 
 
 class Seq2SeqForConditionalGeneration(nn.Module):
@@ -228,14 +244,16 @@ class Seq2SeqForConditionalGeneration(nn.Module):
         self_past_key_values,
         cross_past_key_values,
         cache_position,
+        block_tables: Optional[torch.Tensor] = None,
     ):
-        hidden_states
+        hidden_states = self.decoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             self_past_key_values=self_past_key_values,
             cross_past_key_values=cross_past_key_values,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
 
         if self.has_rescaling and self.config.tie_word_embeddings:
@@ -243,7 +261,7 @@ class Seq2SeqForConditionalGeneration(nn.Module):
 
         lm_logits = self.lm_head(hidden_states)
 
-        return lm_logits
+        return lm_logits
 
 
 class Seq2SeqDecoder(torch.nn.Module):
@@ -292,6 +310,7 @@ class Seq2SeqDecoder(torch.nn.Module):
         self_past_key_values: torch.Tensor,
         cross_past_key_values: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         # embedding
         hidden_states = self.get_embedding()(input_ids)
@@ -303,24 +322,23 @@ class Seq2SeqDecoder(torch.nn.Module):
         hidden_states = self.apply_position_embedding(hidden_states, cache_position)
 
         # iterate decoder_layer
-        self_present_key_values = ()
         for decoder_layer, self_past_key_value, cross_past_key_value in zip(
             self.layers, self_past_key_values, cross_past_key_values
         ):
-            hidden_states
+            hidden_states = decoder_layer(
                 hidden_states,
                 attention_mask=attention_mask,
                 encoder_attention_mask=encoder_attention_mask,
                 self_past_key_value=self_past_key_value,
                 cross_past_key_value=cross_past_key_value,
                 cache_position=cache_position,
+                block_tables=block_tables,
             )
-            self_present_key_values += self_present_key_value
 
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
 
-        return hidden_states
+        return hidden_states
 
 
 class Seq2SeqDecoderLayer(torch.nn.Module):
@@ -373,17 +391,19 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         self_past_key_value: Tuple[torch.Tensor],
         cross_past_key_value: Tuple[torch.Tensor],
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])
 
         # Self Attention Block
         residual = hidden_states
         hidden_states = self.pre_self_attn_layer_norm(hidden_states)
-        hidden_states
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_past_key_value,
             attention_mask=attention_mask,
             cache_position=cache_position,
+            block_tables=block_tables,
         )
         hidden_states = residual + hidden_states
         hidden_states = self.post_self_attn_layer_norm(hidden_states)
@@ -403,14 +423,14 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
         # Feed-Forward Block
         hidden_states = self.ff_layer(hidden_states)
 
-        return hidden_states
+        return hidden_states
 
 
 class Seq2SeqSelfAttention(nn.Module):
-    def __init__(self, attn):
+    def __init__(self, attn, **kwargs):
         super().__init__()
         self._original_mod = attn
-        self.__post_init__()
+        self.__post_init__(**kwargs)
 
     def __post_init__(self, **kwargs):
         """
@@ -442,6 +462,7 @@ class Seq2SeqSelfAttention(nn.Module):
         past_key_value: Tuple[torch.Tensor],
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
+        block_tables: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
         bsz, tgt_len, _ = hidden_states.size()
 
@@ -450,23 +471,26 @@ class Seq2SeqSelfAttention(nn.Module):
         key_states = self._shape(key_states, -1, bsz)
         value_states = self._shape(value_states, -1, bsz)
 
-
+        block_size = past_key_value[0].shape[-2]
+        args = [
             query_states,
             key_states,
             value_states,
-            attention_mask.unsqueeze(
-                2
-            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
             past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
             past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
             cache_position,
             torch.tensor(1.0, dtype=torch.float32),  # scale
-
+            block_tables,
+            block_size,
+        ]
+        if attention_mask is not None:
+            args.insert(3, attention_mask.unsqueeze(2))
+
+        attn_output = self.attn_decode(*args)
 
         attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
         attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
 
         attn_output = self.out_proj(attn_output)
-        present_key_value = (key_states, value_states)
 
-        return attn_output
+        return attn_output
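The attention call is now assembled as a positional list so one code path serves both registered ops: the (group-axis-unsqueezed) mask is only spliced in at position 3, right after Q/K/V, when a mask is present. A compact sketch of that assembly with placeholder values (the custom-op handle itself is not invoked here):

```python
from typing import Optional


def build_attn_args(q, k, v, k_cache, v_cache, cache_position, scale,
                    block_tables, block_size, attention_mask: Optional[object] = None) -> list:
    """Argument layout handed to the paged attention custom op (illustrative)."""
    args = [q, k, v, k_cache, v_cache, cache_position, scale, block_tables, block_size]
    if attention_mask is not None:
        args.insert(3, attention_mask)  # masked variant: the mask rides right behind Q/K/V
    return args


print(len(build_attn_args(*range(9))))                         # 9  -> paged causal attention
print(len(build_attn_args(*range(9), attention_mask="mask")))  # 10 -> paged attention with mask
```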
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -120,7 +120,7 @@ class RBLNT5EncoderModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
-        signature_params = inspect.signature(cls.
+        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
 
         if rbln_model_input_names is None:
             for tokenizer in preprocessors:
@@ -189,6 +189,8 @@ class RBLNT5EncoderModel(RBLNModel):
 
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_causal_paged_attn = False
+
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         enc_max_seq_len = rbln_config.model_cfg["enc_max_seq_len"]
optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers.utils import logging
 
-from ....ops import
+from ....ops import register_rbln_custom_add_softmax_attention
 from ..seq2seq.seq2seq_architecture import (
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
@@ -55,7 +55,7 @@ class T5EncoderWrapper(Seq2SeqEncoderWrapper):
 
 class T5DecoderWrapper(Seq2SeqDecoderWrapper):
     def __post_init__(self, model, dec_max_seq_len: int = None):
-
+        register_rbln_custom_add_softmax_attention()
         self.num_layers = self.config.num_layers
         self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)
 
@@ -71,6 +71,34 @@ class T5DecoderWrapper(Seq2SeqDecoderWrapper):
 
         return new_model
 
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        encoder_attention_mask,
+        cache_position,
+        cross_kv_cache,
+        *self_kv_cache,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
+        self_past_key_values = ()
+        cross_past_key_values = ()
+
+        for i in range(0, self.num_layers * 2, 2):
+            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+        # decode
+        lm_logits = self.conditional_generation(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_attention_mask=encoder_attention_mask,
+            self_past_key_values=self_past_key_values,
+            cross_past_key_values=cross_past_key_values,
+            cache_position=cache_position,
+        )
+
+        return lm_logits
+
 
 class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
     has_rescaling = True
@@ -134,7 +162,7 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         self.out_proj = self._original_mod.o
         self.num_heads = self._original_mod.n_heads
         self.head_dim = self._original_mod.key_value_proj_dim
-        self.attn_decode = torch.ops.rbln_custom_ops.
+        self.attn_decode = torch.ops.rbln_custom_ops.add_softmax_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
@@ -142,6 +170,40 @@ class T5LayerSelfAttention(Seq2SeqSelfAttention):
         value_states = self.v_proj(hidden_states)
         return query_states, key_states, value_states
 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        past_key_value: Tuple[torch.Tensor],
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states = self._shape(query_states, tgt_len, bsz)
+        key_states = self._shape(key_states, -1, bsz)
+        value_states = self._shape(value_states, -1, bsz)
+
+        attn_output = self.attn_decode(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask.unsqueeze(
+                2
+            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
+            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
+            cache_position,
+            torch.tensor(1.0, dtype=torch.float32),  # scale
+        )
+
+        attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
+
+        attn_output = self.out_proj(attn_output)
+        return attn_output
+
 
 class T5CrossAttention(nn.Module):
     def __init__(self, attn):
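The `view` calls above hand the KV cache to the custom kernel with an explicit group axis of size 1. A tiny sketch of that reshape (sizes are illustrative):

```python
import torch

bsz, num_heads, max_seq_len, head_dim = 1, 8, 16, 64  # illustrative sizes
k_cache = torch.zeros(bsz, num_heads, max_seq_len, head_dim)

# Insert the group axis (size 1) that the custom kernel expects:
k_cache_grouped = k_cache.view(bsz, num_heads, 1, -1, head_dim)
print(k_cache_grouped.shape)  # torch.Size([1, 8, 1, 16, 64])
```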