optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/ops/__init__.py +4 -4
- optimum/rbln/ops/attn.py +44 -84
- optimum/rbln/ops/flash_attn.py +25 -25
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +79 -51
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +157 -34
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +7 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +7 -2
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +3 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -3
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +44 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +50 -19
- optimum/rbln/transformers/models/t5/modeling_t5.py +211 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +69 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +19 -24
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/RECORD +22 -22
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a3.dist-info}/licenses/LICENSE +0 -0
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import inspect
+from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
@@ -50,17 +51,28 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        block_tables: torch.Tensor,
+        free_block_pool: Deque,
+        kvcache_block_size: int,
         use_attention_mask: bool,
+        attn_impl: str,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
 
+        # shared data structures between prefill and decode phase
        self.use_attention_mask = use_attention_mask
 
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
+        self.block_tables = block_tables
+        self.free_block_pool = free_block_pool
+
+        self.kvcache_block_size = kvcache_block_size
+        self.empty_block = -1
+        self.attn_impl = attn_impl
 
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
@@ -71,6 +83,75 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
             )
 
+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+        """
+        Manages and returns the KV cache block tables.
+        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+        Args:
+            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+        Returns:
+            torch.Tensor: Updated block tables.
+        """
+
+        NO_BLOCKS_ERROR = (
+            "No memory blocks are available for allocation."
+            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln."
+            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html)."
+            "Using vllm-rbln should fix this issue and enhance inference performance."
+        )
+
+        def update_block(batch_idx: int, block_idx: int):
+            """
+            If the block is empty (empty_block), allocates a block from the free_block_pool.
+            """
+            if self.block_tables[batch_idx][block_idx] == self.empty_block:
+                if self.free_block_pool:
+                    block = self.free_block_pool.popleft()
+                    self.block_tables[batch_idx][block_idx] = block
+                else:
+                    raise RuntimeError(NO_BLOCKS_ERROR)
+
+        def replace_empty_block(block_tables: torch.Tensor):
+            """
+            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+            """
+            if not torch.any(block_tables == self.empty_block):
+                return block_tables.clone()
+            elif self.free_block_pool:
+                _free_block = self.free_block_pool[0]
+                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
+            else:
+                raise RuntimeError(NO_BLOCKS_ERROR)
+
+        if self.phase == "prefill":
+            # Track previously used blocks and return them to the free_block_pool and
+            # reset the current batch's block table to empty blocks
+            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+            self.free_block_pool.extend(prev_blocks)
+            self.block_tables[batch_idx].fill_(self.empty_block)
+
+            # Get the start (s) and end (e) positions from cache_position and
+            # iterate over the cache positions to allocate necessary blocks
+            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+            for position in range(s, e + 1, self.kvcache_block_size):
+                block_idx = position // self.kvcache_block_size
+                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                update_block(batch_idx, block_idx)
+
+            return replace_empty_block(self.block_tables[batch_idx])
+        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+        else:
+            for b_idx in range(self.batch_size):
+                position = cache_position[b_idx][0].item()
+                block_idx = position // self.kvcache_block_size
+                update_block(b_idx, block_idx)
+
+            return replace_empty_block(self.block_tables)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
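
The added get_block_tables() method is the bookkeeping half of paged attention: every sequence owns one row of a shared block table, -1 marks an unallocated logical block, and physical blocks are handed out from a shared free pool. A minimal standalone sketch of the same idea follows (the BlockTableSketch class, its method names, and its sizes are illustrative assumptions, not optimum-rbln API):

from collections import deque

import torch


class BlockTableSketch:
    def __init__(self, batch_size: int, max_seq_len: int, block_size: int, num_blocks: int):
        self.block_size = block_size
        # One row per sequence; -1 means "no physical block assigned to this logical block yet".
        self.block_tables = torch.full((batch_size, max_seq_len // block_size), -1, dtype=torch.int16)
        self.free_block_pool = deque(range(num_blocks))

    def allocate_prefill(self, batch_idx: int, start: int, end: int) -> torch.Tensor:
        # Return any blocks this sequence held before, then allocate blocks covering [start, end].
        prev = self.block_tables[batch_idx][self.block_tables[batch_idx] != -1].tolist()
        self.free_block_pool.extend(prev)
        self.block_tables[batch_idx].fill_(-1)
        for pos in range(start, end + 1, self.block_size):
            self._ensure_block(batch_idx, pos // self.block_size)
        return self.block_tables[batch_idx]

    def allocate_decode(self, cache_positions: torch.Tensor) -> torch.Tensor:
        # One new token per sequence; a block is only needed when a sequence crosses a block boundary.
        for b_idx, pos in enumerate(cache_positions.tolist()):
            self._ensure_block(b_idx, pos // self.block_size)
        return self.block_tables

    def _ensure_block(self, batch_idx: int, block_idx: int) -> None:
        if self.block_tables[batch_idx][block_idx] == -1:
            if not self.free_block_pool:
                raise RuntimeError("no free KV-cache blocks left")
            self.block_tables[batch_idx][block_idx] = self.free_block_pool.popleft()


# Example: two sequences, logical blocks of 128 tokens, a pool of 4 physical blocks.
tables = BlockTableSketch(batch_size=2, max_seq_len=1024, block_size=128, num_blocks=4)
print(tables.allocate_prefill(batch_idx=0, start=0, end=300))  # seq 0 gets physical blocks 0, 1, 2
print(tables.allocate_decode(torch.tensor([301, 0])))          # seq 1 gets its first block (3)
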
@@ -78,6 +159,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -89,19 +171,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds
 
+        if block_tables is None:
+            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
+            is_external_block_tables = False
+        else:
+            is_external_block_tables = True
+
         if self.phase == "decode":
             return self.decode_forward(
                 inputs,
                 cache_position,
+                block_tables,
+                is_external_block_tables,
                 attention_mask=attention_mask,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
 
     def decode_forward(
         self,
         inputs: torch.Tensor,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
@@ -120,7 +212,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )
-
+
+                if is_external_block_tables:
+                    self.dec_attn_mask[b_idx].fill_(0)
+                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                else:
+                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+            attention_mask = self.dec_attn_mask
 
         attention_mask = self.dec_attn_mask
 
@@ -128,6 +227,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
            attention_mask if self.use_attention_mask else None,
+            block_tables,
        )
 
        return logits
@@ -138,6 +238,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -145,11 +247,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
        """
 
-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -207,33 +304,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
                 chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
-            # Define
-            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            # Define query position
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
 
-            if self.use_attention_mask:
-                args = (
-                    input_chunk,
-                    cache_pos_chunk,
-                    chunked_attention_mask,
-                    batch_position,
-                    query_position,
-                )
-            else:
-                args = (
-                    input_chunk,
-                    cache_pos_chunk,
-                    batch_position,
-                    query_position,
-                )
             # Forward pass for the current chunk
             logits = super().forward(
-
+                input_chunk,
+                cache_pos_chunk,
+                chunked_attention_mask if self.use_attention_mask else None,
+                query_position,
+                block_tables,
                 out=out_buffers,
             )
 
-
-
+        # Update decoder attention mask with processed KV-cache length from prefill phase
+        if not is_external_block_tables and self.use_attention_mask:
            self.dec_attn_mask[batch_idx].fill_(0)
            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
@@ -275,9 +360,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
+        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
+        # FIXME get kvcache_num_blocks from compiled results.
+        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
         self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-
+        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
+
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
@@ -291,7 +380,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None
 
+        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        block_tables = torch.zeros(
+            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
@@ -299,10 +394,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="prefill",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            max_seq_len=self.max_seq_len,
             prefill_chunk_size=self.prefill_chunk_size,
+            max_seq_len=self.max_seq_len,
             use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
@@ -311,7 +410,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="decode",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
             use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
 
     @classmethod
@@ -409,6 +512,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
         wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
         wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
         wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
 
@@ -474,6 +578,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
 
@@ -500,12 +605,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
 
-        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
             rbln_attn_impl=rbln_attn_impl,
             rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_kvcache_block_size=rbln_kvcache_block_size,
             rbln_max_seq_len=rbln_max_seq_len,
         )
 
+        if rbln_kvcache_block_size is None:
+            if rbln_attn_impl == "eager":
+                rbln_kvcache_block_size = rbln_max_seq_len
+            else:
+                rbln_kvcache_block_size = rbln_kvcache_partition_len
+
+        # FIXME temporal num_blocks
+        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
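
The block-size default above and the temporary num_blocks formula together reserve exactly enough physical blocks for every sequence in the batch to reach rbln_max_seq_len. A quick sanity check of the arithmetic, with hypothetical values:

# Hypothetical configuration, only to illustrate the formula above.
rbln_max_seq_len = 8192
rbln_kvcache_partition_len = 4096  # flash_attn partition length
rbln_batch_size = 2

rbln_kvcache_block_size = rbln_kvcache_partition_len  # flash_attn default; "eager" would use rbln_max_seq_len
rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
assert rbln_kvcache_num_blocks == 4  # 2 blocks per sequence x 2 sequences
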
@@ -542,19 +657,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if query_length > 1:
             input_info.extend(
                 [
-                    ("batch_position", [], "int16"),
                     ("query_position", [], "int16"),
                 ]
             )
 
+        max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
                     [
-
+                        rbln_kvcache_num_blocks,
                         num_key_value_heads,
-
+                        rbln_kvcache_block_size,
                         head_dim,
                     ],
                     "float32",
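
With the paged layout, each past_key_values entry is no longer shaped per batch and full sequence; it is a pool of blocks indexed through block_tables. A sketch of the resulting input_info entries for a hypothetical model (batch_size=2, max_seq_len=8192, kvcache_block_size=4096, 8 KV heads, head_dim=128; the concrete numbers are assumptions):

max_block_cnt = 8192 // 4096      # 2 logical blocks per sequence
num_blocks = max_block_cnt * 2    # temporary formula: enough blocks for every sequence

prefill_block_tables = ("block_tables", [max_block_cnt], "int16")        # one sequence per prefill call
decode_block_tables = ("block_tables", [2, max_block_cnt], "int16")      # whole batch per decode step
kv_cache = ("past_key_values_0", [num_blocks, 8, 4096, 128], "float32")  # blocks x kv_heads x block_size x head_dim
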
@@ -595,7 +716,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
+                "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
+                "kvcache_num_blocks": rbln_kvcache_num_blocks,
             }
         )
 
@@ -40,10 +40,15 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(
+                new_self_attn = ExaoneAttention(
+                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention,
+                    layer.attn.attention,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -33,10 +33,15 @@ class GemmaWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -37,7 +37,9 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = GPT2Attention(
+            new_self_attn = GPT2Attention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -60,7 +60,9 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = MidmAttention(
+            new_self_attn = MidmAttention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = MidmLayer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = MidmModel(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -36,7 +36,9 @@ class PhiWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = PhiAttention(
+                new_self_attn = PhiAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
             else:
@@ -81,10 +83,10 @@ class PhiLayer(DecoderOnlyLayer):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
@@ -94,10 +96,10 @@ class PhiLayer(DecoderOnlyLayer):
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )
 
         feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
@@ -50,11 +50,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
 
     def forward(
         self,
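
When no block tables are passed in, the seq2seq decoder falls back to an identity mapping: batch entry b reads and writes block b. A small illustration of that default (the batch size is an arbitrary example value):

import torch

batch_size = 4
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
# tensor([[0], [1], [2], [3]], dtype=torch.int16) -- one dedicated block per batch entry,
# matching the ("block_tables", [rbln_batch_size, 1], "int16") decoder input declared later in this diff.
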
@@ -62,6 +65,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         batch_size = decoder_input_ids.shape[0]
@@ -73,19 +77,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-
-
-
-
-
-
-
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        if block_tables is None:
+            block_tables = self.default_block_tables
 
         lm_logits = super().forward(
             decoder_input_ids,
-            decoder_attention_mask,
+            decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
+            block_tables,
         )
 
         return Seq2SeqLMOutput(logits=lm_logits)
@@ -110,12 +119,18 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+
         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
             main_input_name="input_ids",
         )
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )
 
     @classmethod
@@ -171,6 +186,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
 
         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -232,18 +254,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 ],
                 "float32",
             ),
-            ("
+            ("block_tables", [1], "int16"),
         ]
 
         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
                 [rbln_batch_size, 1],
                 "int32",
             ),
+            (
+                "block_tables",
+                [rbln_batch_size, 1],
+                "int16",
+            ),
         ]
         dec_input_info.extend(
             [
@@ -275,6 +301,10 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 for i in range(n_layer * 2)
             ]
         )
+
+        if rbln_use_attention_mask:
+            dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
@@ -290,6 +320,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "pad_token_id": rbln_pad_token_id,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )
 
@@ -400,9 +431,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         encoder_kwargs["output_attentions"] = False
 
         for b in range(batch_size):
-
+            block_tables = torch.tensor([b], dtype=torch.int16)
             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs,
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
 
         return model_kwargs