optimum-rbln 0.7.3a1__py3-none-any.whl → 0.7.3a2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/ops/__init__.py +4 -4
- optimum/rbln/ops/attn.py +44 -84
- optimum/rbln/ops/flash_attn.py +25 -25
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +79 -51
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +161 -34
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +7 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +7 -2
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +3 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +3 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +5 -3
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +44 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +50 -19
- optimum/rbln/transformers/models/t5/modeling_t5.py +211 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +69 -3
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a2.dist-info}/RECORD +21 -21
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3a1.dist-info → optimum_rbln-0.7.3a2.dist-info}/licenses/LICENSE +0 -0
@@ -13,9 +13,10 @@
 # limitations under the License.

 import inspect
+from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
@@ -50,17 +51,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        block_tables: torch.Tensor,
+        free_block_pool: Deque,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
         use_attention_mask: bool,
+        attn_impl: str,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size

+        # shared data structures between prefill and decode phase
         self.use_attention_mask = use_attention_mask

         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
+        self.block_tables = block_tables
+        self.free_block_pool = free_block_pool
+
+        self.kvcache_block_size = kvcache_block_size
+        self.empty_block = kvcache_num_blocks - 1
+        self.attn_impl = attn_impl

         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
@@ -71,6 +84,72 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
             )

+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+        """
+        Manages and returns the KV cache block tables.
+        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+        Args:
+            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+        Returns:
+            torch.Tensor: Updated block tables.
+        """
+
+        def update_block(batch_idx, block_idx):
+            """
+            Helper function to update the block table for a given batch index and block index.
+            If the block is empty (empty_block), allocates a block from the free_block_pool.
+
+            Args:
+                batch_idx (int): Batch index.
+                block_idx (int): Block index.
+
+            Raises:
+                RuntimeError: Raised if no available blocks are found in the free_block_pool.
+            """
+            if self.block_tables[batch_idx][block_idx] == self.empty_block:
+                if self.free_block_pool:
+                    block = self.free_block_pool.popleft()
+                    self.block_tables[batch_idx][block_idx] = block
+                else:
+                    raise RuntimeError("Not available blocks")
+
+        if self.attn_impl == "eager":
+            if self.phase == "prefill":
+                return self.block_tables[batch_idx]
+            else:
+                return self.block_tables
+        # Case for 'flash_attn' attention implementation
+        else:
+            if self.phase == "prefill":
+                # Track previously used blocks and return them to the free_block_pool and
+                # reset the current batch's block table to empty blocks
+                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+                self.free_block_pool.extend(prev_blocks)
+                self.block_tables[batch_idx].fill_(self.empty_block)
+
+                # Get the start (s) and end (e) positions from cache_position and
+                # iterate over the cache positions to allocate necessary blocks
+                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+                for position in range(s, e + 1, self.kvcache_block_size):
+                    block_idx = position // self.kvcache_block_size
+                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                    update_block(batch_idx, block_idx)
+
+                return self.block_tables[batch_idx]
+
+            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+            else:
+                for b_idx in range(self.batch_size):
+                    position = cache_position[b_idx][0].item()
+                    block_idx = position // self.kvcache_block_size
+                    update_block(b_idx, block_idx)
+
+                return self.block_tables
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
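The hunk above adds a small paged KV-cache allocator: every sequence owns one row of `block_tables`, unused slots point at a sentinel block (`kvcache_num_blocks - 1`), and physical blocks are handed out from a shared `deque`. The sketch below reproduces only that bookkeeping in isolation, with made-up toy sizes, to show how prefill fills a row block by block; it is illustrative and not part of the package.

```python
from collections import deque

import torch

# Assumed toy sizes, not package defaults.
BATCH_SIZE, MAX_SEQ_LEN, BLOCK_SIZE = 2, 8, 4
NUM_BLOCKS = (MAX_SEQ_LEN // BLOCK_SIZE) * BATCH_SIZE  # 4 physical blocks
EMPTY_BLOCK = NUM_BLOCKS - 1                           # last block doubles as the "empty" sentinel

block_tables = torch.full((BATCH_SIZE, MAX_SEQ_LEN // BLOCK_SIZE), EMPTY_BLOCK, dtype=torch.int16)
free_block_pool = deque(range(NUM_BLOCKS - 1))         # the sentinel block is never handed out


def allocate(batch_idx: int, position: int) -> None:
    """Back the logical block covering `position` with a physical block, if it has none yet."""
    block_idx = position // BLOCK_SIZE
    if block_tables[batch_idx][block_idx] == EMPTY_BLOCK:
        if not free_block_pool:
            raise RuntimeError("Not available blocks")
        block_tables[batch_idx][block_idx] = free_block_pool.popleft()


# Prefill 6 tokens for sequence 0: cache positions 0..5 span logical blocks 0 and 1.
for pos in range(0, 6, BLOCK_SIZE):
    allocate(0, pos)
print(block_tables.tolist())   # [[0, 1], [3, 3]] -> sequence 0 got physical blocks 0 and 1
print(list(free_block_pool))   # [2]
```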
@@ -78,6 +157,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -89,19 +169,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds

+        if block_tables is None:
+            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
+            is_external_block_tables = False
+        else:
+            is_external_block_tables = True
+
         if self.phase == "decode":
             return self.decode_forward(
                 inputs,
                 cache_position,
+                block_tables,
+                is_external_block_tables,
                 attention_mask=attention_mask,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)

     def decode_forward(
         self,
         inputs: torch.Tensor,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
@@ -120,7 +210,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )
-
+
+                if is_external_block_tables:
+                    self.dec_attn_mask[b_idx].fill_(0)
+                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                else:
+                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+                attention_mask = self.dec_attn_mask

             attention_mask = self.dec_attn_mask

@@ -128,6 +225,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             inputs,
             cache_position,
             attention_mask if self.use_attention_mask else None,
+            block_tables,
         )

         return logits
@@ -138,6 +236,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -145,11 +245,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """

-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -207,33 +302,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

-            # Define
-            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            # Define query position
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)

-            if self.use_attention_mask:
-                args = (
-                    input_chunk,
-                    cache_pos_chunk,
-                    chunked_attention_mask,
-                    batch_position,
-                    query_position,
-                )
-            else:
-                args = (
-                    input_chunk,
-                    cache_pos_chunk,
-                    batch_position,
-                    query_position,
-                )
             # Forward pass for the current chunk
             logits = super().forward(
-
+                input_chunk,
+                cache_pos_chunk,
+                chunked_attention_mask if self.use_attention_mask else None,
+                query_position,
+                block_tables,
                 out=out_buffers,
             )

-
-
+        # Update decoder attention mask with processed KV-cache length from prefill phase
+        if not is_external_block_tables and self.use_attention_mask:
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

@@ -275,9 +358,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
+        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
+        # FIXME get kvcache_num_blocks from compiled results.
+        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
         self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-
+        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
+
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
@@ -291,7 +378,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None

+        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        if attn_impl == "eager":
+            block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).reshape(self.batch_size, 1)
+            free_block_pool = None
+        else:
+            block_tables = torch.zeros(
+                self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+            ).fill_(self.kvcache_num_blocks - 1)
+            free_block_pool = deque(x for x in range(self.kvcache_num_blocks - 1))
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
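For reference, the two initialization branches above yield differently shaped shared tables. A minimal standalone sketch with assumed sizes (`batch_size=2`, `max_seq_len=8`, `kvcache_block_size=4`, `kvcache_num_blocks=4`), mirroring the code in this hunk:

```python
from collections import deque

import torch

batch_size, max_seq_len = 2, 8                 # assumed values, for illustration only
kvcache_block_size, kvcache_num_blocks = 4, 4

# eager: one block per sequence, mapped by batch index; no free pool is needed
eager_block_tables = torch.arange(0, batch_size, dtype=torch.int16).reshape(batch_size, 1)
print(eager_block_tables.tolist())             # [[0], [1]]

# flash_attn: every slot starts at the sentinel (last) block, the rest go into the free pool
flash_block_tables = torch.zeros(
    batch_size, max_seq_len // kvcache_block_size, dtype=torch.int16
).fill_(kvcache_num_blocks - 1)
free_block_pool = deque(range(kvcache_num_blocks - 1))
print(flash_block_tables.tolist())             # [[3, 3], [3, 3]]
print(list(free_block_pool))                   # [0, 1, 2]
```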
@@ -299,10 +396,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="prefill",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
+            kvcache_num_blocks=self.kvcache_num_blocks,
             vocab_size=self.config.vocab_size,
-            max_seq_len=self.max_seq_len,
             prefill_chunk_size=self.prefill_chunk_size,
+            max_seq_len=self.max_seq_len,
             use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
@@ -311,7 +413,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="decode",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
+            kvcache_num_blocks=self.kvcache_num_blocks,
             use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )

     @classmethod
@@ -409,6 +516,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
         wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
         wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
         wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")

@@ -474,6 +582,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)

@@ -500,12 +609,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

-        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
             rbln_attn_impl=rbln_attn_impl,
             rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_kvcache_block_size=rbln_kvcache_block_size,
             rbln_max_seq_len=rbln_max_seq_len,
         )

+        if rbln_kvcache_block_size is None:
+            if rbln_attn_impl == "eager":
+                rbln_kvcache_block_size = rbln_max_seq_len
+            else:
+                rbln_kvcache_block_size = rbln_kvcache_partition_len
+
+        # FIXME temporal num_blocks
+        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
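As a quick worked example of the fallback and block-count arithmetic above (the numbers are assumed, not package defaults): with a 4096-token sequence and batch size 2, eager attention gets one block per sequence, while flash attention splits the cache into partition-sized blocks.

```python
# Assumed example values, not defaults shipped with the package.
rbln_max_seq_len, rbln_batch_size = 4096, 2

# eager: the block spans the whole sequence
eager_block_size = rbln_max_seq_len                                           # 4096
eager_num_blocks = (rbln_max_seq_len // eager_block_size) * rbln_batch_size   # 2

# flash_attn: the block size falls back to the KV-cache partition length
flash_block_size = 2048                                                       # assumed kvcache_partition_len
flash_num_blocks = (rbln_max_seq_len // flash_block_size) * rbln_batch_size   # 4
print(eager_num_blocks, flash_num_blocks)                                     # 2 4
```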
@@ -542,19 +661,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if query_length > 1:
             input_info.extend(
                 [
-                    ("batch_position", [], "int16"),
                     ("query_position", [], "int16"),
                 ]
             )

+        max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
                     [
-
+                        rbln_kvcache_num_blocks,
                         num_key_value_heads,
-
+                        rbln_kvcache_block_size,
                         head_dim,
                     ],
                     "float32",
@@ -595,7 +720,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
+                "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
+                "kvcache_num_blocks": rbln_kvcache_num_blocks,
             }
         )

@@ -40,10 +40,15 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(
+                new_self_attn = ExaoneAttention(
+                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention,
+                    layer.attn.attention,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -33,10 +33,15 @@ class GemmaWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
@@ -37,7 +37,9 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = GPT2Attention(
+            new_self_attn = GPT2Attention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -60,7 +60,9 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = MidmAttention(
+            new_self_attn = MidmAttention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = MidmLayer(layer, new_self_attn)
             new_layers.append(new_layer)
         new_model = MidmModel(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
@@ -36,7 +36,9 @@ class PhiWrapper(DecoderOnlyWrapper):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = PhiAttention(
+                new_self_attn = PhiAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
             else:
@@ -81,10 +83,10 @@ class PhiLayer(DecoderOnlyLayer):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         seq_positions: torch.LongTensor,
-        batch_position: torch.Tensor,
         past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states

@@ -94,10 +96,10 @@ class PhiLayer(DecoderOnlyLayer):
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             seq_positions=seq_positions,
-            batch_position=batch_position,
             past_key_values=past_key_values,
             cos=cos,
             sin=sin,
+            block_tables=block_tables,
         )

         feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
@@ -50,11 +50,14 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         runtime: rebel.Runtime,
         batch_size: int,
         dec_max_seq_len: int,
+        use_attention_mask: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.batch_size = batch_size
         self.dec_max_seq_len = dec_max_seq_len
+        self.use_attention_mask = use_attention_mask
+        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)

     def forward(
         self,
@@ -62,6 +65,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_attention_mask: Optional[torch.BoolTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        block_tables: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         batch_size = decoder_input_ids.shape[0]
@@ -73,19 +77,24 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

-
-
-
-
-
-
-
+        if self.use_attention_mask:
+            for b_idx in range(self.batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_max_seq_len):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        if block_tables is None:
+            block_tables = self.default_block_tables

         lm_logits = super().forward(
             decoder_input_ids,
-            decoder_attention_mask,
+            decoder_attention_mask if self.use_attention_mask else None,
             attention_mask,
             cache_position,
+            block_tables,
         )

         return Seq2SeqLMOutput(logits=lm_logits)
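In the seq2seq decoder the paging is degenerate: when no table is passed in, the runtime falls back to `default_block_tables`, which maps batch index `b` straight to cache block `b`, matching the `[batch_size, 1]`-shaped `block_tables` input the decode graph expects. A tiny illustration with an assumed batch size:

```python
import torch

batch_size = 3  # assumed
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables.tolist())  # [[0], [1], [2]]
```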
@@ -110,12 +119,18 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     def __post_init__(self, **kwargs):
         batch_size = self.rbln_config.model_cfg["batch_size"]
         dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.use_attention_mask = self.rbln_config.model_cfg.get("use_attention_mask", None)
+
         self.encoder = RBLNRuntimeEncoder(
             runtime=self.model[0],
             main_input_name="input_ids",
         )
         self.decoder = RBLNRuntimeDecoder(
-            runtime=self.model[1],
+            runtime=self.model[1],
+            main_input_name="input_ids",
+            batch_size=batch_size,
+            dec_max_seq_len=dec_max_seq_len,
+            use_attention_mask=self.use_attention_mask,
         )

     @classmethod
@@ -171,6 +186,13 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
+
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True

         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
@@ -232,18 +254,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 ],
                 "float32",
             ),
-            ("
+            ("block_tables", [1], "int16"),
         ]

         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
             ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
                 [rbln_batch_size, 1],
                 "int32",
             ),
+            (
+                "block_tables",
+                [rbln_batch_size, 1],
+                "int16",
+            ),
         ]
         dec_input_info.extend(
             [
@@ -275,6 +301,10 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 for i in range(n_layer * 2)
             ]
         )
+
+        if rbln_use_attention_mask:
+            dec_input_info.insert(1, ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"))
+
         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

@@ -290,6 +320,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
                 "dec_max_seq_len": rbln_dec_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "pad_token_id": rbln_pad_token_id,
+                "use_attention_mask": rbln_use_attention_mask,
             }
         )

@@ -400,9 +431,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         encoder_kwargs["output_attentions"] = False

         for b in range(batch_size):
-
+            block_tables = torch.tensor([b], dtype=torch.int16)
             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs,
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)

         return model_kwargs