optimum-rbln 0.7.2rc2__py3-none-any.whl → 0.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +8 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +8 -0
- optimum/rbln/diffusers/modeling_diffusers.py +103 -117
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -3
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +15 -8
- optimum/rbln/diffusers/pipelines/__init__.py +8 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +7 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +107 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +25 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +3 -0
- optimum/rbln/modeling.py +4 -1
- optimum/rbln/modeling_base.py +16 -3
- optimum/rbln/ops/__init__.py +6 -2
- optimum/rbln/ops/attn.py +94 -85
- optimum/rbln/ops/flash_attn.py +46 -25
- optimum/rbln/ops/kv_cache_update.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +10 -6
- optimum/rbln/transformers/models/bart/modeling_bart.py +6 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +264 -133
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +276 -29
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +11 -4
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +11 -4
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +5 -3
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -3
- optimum/rbln/transformers/models/phi/phi_architecture.py +9 -7
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +50 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +60 -36
- optimum/rbln/transformers/models/t5/modeling_t5.py +3 -1
- optimum/rbln/transformers/models/t5/t5_architecture.py +65 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +26 -36
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -14
- optimum/rbln/utils/import_utils.py +7 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/METADATA +1 -1
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/RECORD +40 -38
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.2rc2.dist-info → optimum_rbln-0.7.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 import inspect
+import math
+from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
@@ -50,14 +52,28 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         phase: str,
         batch_size: int,
         dec_attn_mask: torch.Tensor,
+        block_tables: torch.Tensor,
+        free_block_pool: Deque,
+        kvcache_block_size: int,
+        use_attention_mask: bool,
+        attn_impl: str,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
 
+        # shared data structures between prefill and decode phase
+        self.use_attention_mask = use_attention_mask
+
         # shared tensor between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
+        self.block_tables = block_tables
+        self.free_block_pool = free_block_pool
+
+        self.kvcache_block_size = kvcache_block_size
+        self.empty_block = -1
+        self.attn_impl = attn_impl
 
         if self.phase == "prefill":
             vocab_size = kwargs.pop("vocab_size")
@@ -68,6 +84,75 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
             )
 
+    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None):
+        """
+        Manages and returns the KV cache block tables.
+        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
+
+        Args:
+            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
+            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
+
+        Returns:
+            torch.Tensor: Updated block tables.
+        """
+
+        NO_BLOCKS_ERROR = (
+            "No memory blocks are available for allocation. "
+            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
+            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
+            "Using vllm-rbln should fix this issue and enhance inference performance."
+        )
+
+        def update_block(batch_idx: int, block_idx: int):
+            """
+            If the block is empty (empty_block), allocates a block from the free_block_pool.
+            """
+            if self.block_tables[batch_idx][block_idx] == self.empty_block:
+                if self.free_block_pool:
+                    block = self.free_block_pool.popleft()
+                    self.block_tables[batch_idx][block_idx] = block
+                else:
+                    raise RuntimeError(NO_BLOCKS_ERROR)
+
+        def replace_empty_block(block_tables: torch.Tensor):
+            """
+            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
+            """
+            if not torch.any(block_tables == self.empty_block):
+                return block_tables.clone()
+            elif self.free_block_pool:
+                _free_block = self.free_block_pool[0]
+                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
+            else:
+                raise RuntimeError(NO_BLOCKS_ERROR)
+
+        if self.phase == "prefill":
+            # Track previously used blocks and return them to the free_block_pool and
+            # reset the current batch's block table to empty blocks
+            prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
+            self.free_block_pool.extend(prev_blocks)
+            self.block_tables[batch_idx].fill_(self.empty_block)
+
+            # Get the start (s) and end (e) positions from cache_position and
+            # iterate over the cache positions to allocate necessary blocks
+            s, e = cache_position[0][0].item(), cache_position[0][-1].item()
+            for position in range(s, e + 1, self.kvcache_block_size):
+                block_idx = position // self.kvcache_block_size
+                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
+                update_block(batch_idx, block_idx)
+
+            return replace_empty_block(self.block_tables[batch_idx])
+        # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
+        else:
+            for b_idx in range(self.batch_size):
+                position = cache_position[b_idx][0].item()
+                block_idx = position // self.kvcache_block_size
+                update_block(b_idx, block_idx)
+
+            return replace_empty_block(self.block_tables)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
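
The `get_block_tables` hunk above introduces paged KV-cache bookkeeping: each batch row owns a table of block indices, blocks come from a `free_block_pool` deque shared between the prefill and decode runtimes, prefill returns a row's old blocks before reallocating, and decode allocates a new block only when the current cache position crosses a block boundary. Below is a minimal, self-contained sketch of that allocation idea; the sizes and both helper functions are illustrative, not part of optimum-rbln, and the error path for an exhausted pool is omitted.

```python
from collections import deque

import torch

# Hypothetical sizes, chosen only for illustration.
BATCH_SIZE, MAX_SEQ_LEN, BLOCK_SIZE, NUM_BLOCKS = 2, 512, 128, 8

block_tables = torch.full((BATCH_SIZE, MAX_SEQ_LEN // BLOCK_SIZE), -1, dtype=torch.int16)
free_block_pool = deque(range(NUM_BLOCKS))


def allocate_prefill(batch_idx: int, prompt_len: int) -> torch.Tensor:
    # Return any blocks this row already holds, then take one block per
    # BLOCK_SIZE positions covered by the prompt.
    free_block_pool.extend(int(b) for b in block_tables[batch_idx] if b != -1)
    block_tables[batch_idx].fill_(-1)
    for pos in range(0, prompt_len, BLOCK_SIZE):
        block_tables[batch_idx][pos // BLOCK_SIZE] = free_block_pool.popleft()
    return block_tables[batch_idx]


def allocate_decode(cache_position: torch.Tensor) -> torch.Tensor:
    # A row needs a fresh block only when its next position starts a new block.
    for b_idx in range(BATCH_SIZE):
        block_idx = int(cache_position[b_idx]) // BLOCK_SIZE
        if block_tables[b_idx][block_idx] == -1:
            block_tables[b_idx][block_idx] = free_block_pool.popleft()
    return block_tables


print(allocate_prefill(0, prompt_len=300))      # 300 tokens -> blocks for positions 0, 128, 256
print(allocate_decode(torch.tensor([300, 0])))  # row 0 reuses its block, row 1 gets a new one
```
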
@@ -75,6 +160,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
+        block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -86,19 +172,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         else:
             inputs = inputs_embeds
 
+        if block_tables is None:
+            block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
+            is_external_block_tables = False
+        else:
+            is_external_block_tables = True
+
         if self.phase == "decode":
             return self.decode_forward(
                 inputs,
                 cache_position,
+                block_tables,
+                is_external_block_tables,
                 attention_mask=attention_mask,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
 
     def decode_forward(
         self,
         inputs: torch.Tensor,
         cache_position: torch.Tensor = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
@@ -110,19 +206,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if batch_size != cache_position.shape[0]:
             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
 
-        if attention_mask is None:
+        if self.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )
-
+
+                if is_external_block_tables:
+                    self.dec_attn_mask[b_idx].fill_(0)
+                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                else:
+                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+                attention_mask = self.dec_attn_mask
+
+            attention_mask = self.dec_attn_mask
 
         logits = super().forward(
             inputs,
-            self.dec_attn_mask if attention_mask is None else attention_mask,
             cache_position,
+            attention_mask if self.use_attention_mask else None,
+            block_tables,
         )
 
         return logits
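
In the decode hunk above, the per-row attention mask is only maintained when `use_attention_mask` is set, and how it is updated depends on who owns the block tables. A reduced-shape sketch of the two paths follows; the tensor sizes are made up for illustration. With internally managed block tables only the newly reached position is marked, while an externally supplied block table (for example from a serving engine) causes the whole visible prefix to be rebuilt.

```python
import torch

dec_attn_mask = torch.zeros(1, 1, 1, 8)  # (batch, 1, 1, max_seq_len), max_seq_len=8 for brevity
decoding_step = 5                        # cache position reached by this decode step

is_external_block_tables = False
if is_external_block_tables:
    # External tables: rebuild the prefix so positions 0..decoding_step are visible.
    dec_attn_mask[0].fill_(0)
    dec_attn_mask[0, :, :, : decoding_step + 1] = 1
else:
    # Internal tables: earlier positions were marked on previous steps,
    # so only the current position needs to be switched on.
    dec_attn_mask[0, :, :, decoding_step] = 1

print(dec_attn_mask[0, 0, 0])  # tensor([0., 0., 0., 0., 0., 1., 0., 0.])
```
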
@@ -133,6 +239,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: int = None,
+        block_tables: torch.Tensor = None,
+        is_external_block_tables: bool = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -140,11 +248,6 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
         """
 
-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -156,7 +259,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )
 
         # Initialize attention mask for chunked processing
-        chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+        if self.use_attention_mask:
+            chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
 
         # Buffer for storing output logits
         out_buffers = [
@@ -195,28 +299,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
 
-            # Update attention mask to ensure proper causal behavior
-            if step >= self.prefill_chunk_size:
-                chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-            chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+            if self.use_attention_mask:
+                # Update attention mask to ensure proper causal behavior
+                if step >= self.prefill_chunk_size:
+                    chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
 
-            # Define
-            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            # Define query position
             query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
 
             # Forward pass for the current chunk
             logits = super().forward(
                 input_chunk,
-                chunked_attention_mask,
                 cache_pos_chunk,
-                batch_position,
+                chunked_attention_mask if self.use_attention_mask else None,
                 query_position,
+                block_tables,
                 out=out_buffers,
             )
 
         # Update decoder attention mask with processed KV-cache length from prefill phase
-        self.dec_attn_mask[batch_idx].fill_(0)
-        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+        if not is_external_block_tables and self.use_attention_mask:
+            self.dec_attn_mask[batch_idx].fill_(0)
+            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
         return logits
 
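
The prefill hunks above keep the existing chunked-prefill loop but make the attention mask optional and thread the block table through every chunked call. As a reminder of the chunking arithmetic the loop relies on, here is a small sketch; `chunk_plan` is a hypothetical helper, not a library function, and 128 is simply the default `prefill_chunk_size` in this release.

```python
PREFILL_CHUNK_SIZE = 128  # default prefill chunk size in this release


def chunk_plan(query_length: int):
    # Pad the prompt up to a multiple of the chunk size, enumerate the chunk
    # start offsets, and locate the last real token inside the final chunk.
    padded = query_length
    if query_length % PREFILL_CHUNK_SIZE != 0:
        padded += PREFILL_CHUNK_SIZE - (query_length % PREFILL_CHUNK_SIZE)
    steps = list(range(0, padded, PREFILL_CHUNK_SIZE))
    query_position = (query_length - 1) % PREFILL_CHUNK_SIZE
    return steps, query_position


print(chunk_plan(300))  # ([0, 128, 256], 43): three chunks, last real token at offset 43
```
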
@@ -256,8 +361,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-
+        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
+        # FIXME get kvcache_num_blocks from compiled results.
+        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
+        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
+        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
+
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
@@ -271,7 +381,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None
 
+        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        block_tables = torch.zeros(
+            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
@@ -279,9 +395,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="prefill",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            max_seq_len=self.max_seq_len,
             prefill_chunk_size=self.prefill_chunk_size,
+            max_seq_len=self.max_seq_len,
+            use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
@@ -290,6 +411,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             phase="decode",
             batch_size=self.batch_size,
             dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            kvcache_block_size=self.kvcache_block_size,
+            use_attention_mask=self.use_attention_mask,
+            attn_impl=attn_impl,
         )
 
     @classmethod
@@ -363,7 +489,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
-        val = getattr(self.
+        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
@@ -387,7 +513,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
         wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
         wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["kvcache_block_size"] = rbln_config.model_cfg.get("kvcache_block_size")
         wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb
+        wrapper_cfg["use_attention_mask"] = rbln_config.model_cfg.get("use_attention_mask")
 
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
@@ -438,6 +566,71 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return compile_model(quantize_config=quantize_config)
 
+    @classmethod
+    def get_maximum_num_blocks(
+        cls,
+        config: PretrainedConfig,
+        tensor_parallel_size: int,
+        kvcache_block_size: int,
+        nbits_per_param: int,
+        n_model_params: int,
+    ) -> int:
+        def align(x: int, nbytes: int) -> int:
+            return int(math.ceil(x / nbytes) * nbytes)
+
+        def align_2MB(x: int) -> int:
+            return align(x, 2 * 1024 * 1024)
+
+        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
+        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
+        vocab_size = config.vocab_size
+        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
+        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+        # TODO(jongho): Update if target npu is REBEL.
+        ATOM_DRAM_NBYTES = 16 * 2**30
+        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
+        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
+
+        # Get estimated kernel size (approximated)
+        lm_heads_params = align(vocab_size, 64) * hidden_size
+        lm_heads_nbytes = (
+            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+        )
+        params = n_model_params - lm_heads_params
+        layer_nbytes = (
+            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+            * num_layers
+            * tensor_parallel_size
+        )
+        kernel_size = layer_nbytes + lm_heads_nbytes
+
+        available_dram -= kernel_size
+
+        # TODO: Accurate buffer estimation
+        buffer = 2**30  # 1GB Buffer
+        if tensor_parallel_size <= 4:
+            buffer /= 4
+
+        available_dram -= buffer
+
+        # Estimate nbytes per a single kvcache block
+        nbytes_per_block = (
+            align_2MB(
+                kvcache_block_size
+                * head_dim
+                * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
+                * 2  # (fp16)
+            )
+            * num_layers
+            * 2  # (k, v)
+            * tensor_parallel_size
+        )
+        n_blocks = available_dram // nbytes_per_block
+
+        return n_blocks, nbytes_per_block
+
     @classmethod
     def _get_rbln_config(
         cls,
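
`get_maximum_num_blocks` turns the accelerator's DRAM budget into a KV-cache block count: per-device DRAM minus the reserved system region, minus an approximate kernel footprint and a fixed buffer, divided by the 2 MiB-aligned size of one block. A rough worked example of the same arithmetic follows; the 16 GiB / 288 MiB constants mirror the hunk, while the model shape and parameter count are made-up, 7B-class numbers used purely for illustration.

```python
import math

def align(x, nbytes):
    return int(math.ceil(x / nbytes) * nbytes)

def align_2mb(x):
    return align(x, 2 * 1024 * 1024)

# Illustrative inputs: tensor-parallel degree, block length, weight precision,
# and a roughly 7B-parameter decoder configuration.
tp, block_size, nbits = 4, 16384, 16
num_layers, kv_heads, head_dim, hidden, vocab = 32, 32, 128, 4096, 32000
n_params = 6_700_000_000

available = tp * (16 * 2**30 - 288 * 2**20)  # DRAM minus reserved system memory, per the hunk
lm_head = align(vocab, 64) * hidden
kernel = (
    align_2mb(lm_head * nbits // 8 / tp) * tp
    + align_2mb((n_params - lm_head) * nbits // 8 / num_layers / tp) * num_layers * tp
)
available -= kernel
available -= 2**30  # flat 1 GiB buffer here; the hunk shrinks it to 256 MiB when tp <= 4

per_block = align_2mb(block_size * head_dim * math.ceil(kv_heads / tp) * 2) * num_layers * 2 * tp
print(available // per_block)  # upper bound on KV-cache blocks for this made-up setup
```
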
@@ -448,11 +641,19 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+        rbln_use_attention_mask = rbln_kwargs.get("use_attention_mask", None)
         rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
         rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_kvcache_block_size = rbln_kwargs.get("kvcache_block_size", None)
         rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
         rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
 
+        if rbln_use_attention_mask is None:
+            rbln_use_attention_mask = False
+            rbln_npu = rbln_kwargs.get("npu", None) or rebel.get_npu_name()
+            if rbln_npu == "RBLN-CA02":
+                rbln_use_attention_mask = True
+
         if rbln_prefill_chunk_size is None:
             rbln_prefill_chunk_size = 128
         elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
@@ -470,12 +671,42 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
 
-        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+        rbln_attn_impl, rbln_kvcache_partition_len, rbln_kvcache_block_size = validate_attention_method(
             rbln_attn_impl=rbln_attn_impl,
             rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_kvcache_block_size=rbln_kvcache_block_size,
             rbln_max_seq_len=rbln_max_seq_len,
         )
 
+        if rbln_kvcache_block_size is None:
+            if rbln_attn_impl == "eager":
+                rbln_kvcache_block_size = rbln_max_seq_len
+            else:
+                rbln_kvcache_block_size = rbln_kvcache_partition_len
+
+        rbln_kvcache_num_blocks = (rbln_max_seq_len // rbln_kvcache_block_size) * rbln_batch_size
+        if rbln_attn_impl == "flash_attn":
+            max_num_blocks, _ = cls.get_maximum_num_blocks(
+                config=model_config,
+                tensor_parallel_size=rbln_kwargs.get("tensor_parallel_size", 1),
+                kvcache_block_size=rbln_kvcache_block_size,
+                nbits_per_param=16 if rbln_quantization is None else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=rbln_kwargs["n_model_params"],
+            )
+            rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)
+
+            required_blocks = rbln_max_seq_len // rbln_kvcache_block_size + 1
+            if rbln_kvcache_num_blocks < required_blocks:
+                rbln_kvcache_num_blocks = required_blocks
+
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_kvcache_num_blocks}")
+
+        if rbln_kvcache_num_blocks < rbln_batch_size:
+            raise RuntimeError(
+                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                "Ensure the number of blocks is at least equal to the batch size."
+            )
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
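
The configuration hunk above derives `kvcache_num_blocks` in three steps: request enough blocks for every batch row to hold a full sequence, cap that at the DRAM-derived maximum when flash attention is used, and keep at least one full sequence plus one spare block. The numbers below are illustrative only:

```python
# Hypothetical compile settings for a flash-attention model.
max_seq_len, block_size, batch_size = 8192, 1024, 4
dram_limited_blocks = 20  # pretend this came from get_maximum_num_blocks()

num_blocks = (max_seq_len // block_size) * batch_size        # 32 blocks requested
num_blocks = min(num_blocks, dram_limited_blocks)            # capped at 20 by the DRAM estimate
num_blocks = max(num_blocks, max_seq_len // block_size + 1)  # floor of 9 (one sequence + 1)
assert num_blocks >= batch_size                              # otherwise compilation is refused
print(num_blocks)  # 20
```
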
@@ -495,29 +726,42 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         input_info = [
             main_input,
-            ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
             (
                 "cache_position",
                 [batch_size, query_length],
                 "int32",
             ),
         ]
+
+        if rbln_use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
+                ]
+            )
+
         if query_length > 1:
             input_info.extend(
                 [
-                    ("batch_position", [], "int16"),
                     ("query_position", [], "int16"),
                 ]
             )
 
+        max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
         input_info.extend(
             [
                 (
                     f"past_key_values_{i}",
                     [
-
+                        rbln_kvcache_num_blocks,
                        num_key_value_heads,
-
+                        rbln_kvcache_block_size,
                        head_dim,
                    ],
                    "float32",
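
With these changes the compiled graph drops the `batch_position` scalar, makes the attention mask an optional input, and declares the block tables and KV cache with block-granular shapes. A sketch of the resulting input signatures for an illustrative configuration (all concrete numbers below are invented):

```python
# Illustrative config: max_seq_len=4096, kvcache_block_size=1024, batch_size=2,
# kvcache_num_blocks=9, 8 KV heads, head_dim=128.
max_block_cnt = 4096 // 1024  # 4

prefill_inputs = ("block_tables", [max_block_cnt], "int16")           # one sequence prefilled at a time
decode_inputs = ("block_tables", [2, max_block_cnt], "int16")         # one row of blocks per batch entry
past_key_value = ("past_key_values_0", [9, 8, 1024, 128], "float32")  # (blocks, kv_heads, block_size, head_dim)

print(prefill_inputs, decode_inputs, past_key_value, sep="\n")
```
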
@@ -555,9 +799,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "max_seq_len": rbln_max_seq_len,
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": rbln_prefill_chunk_size,
+                "use_attention_mask": rbln_use_attention_mask,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
                 "kvcache_partition_len": rbln_kvcache_partition_len,
+                "kvcache_block_size": rbln_kvcache_block_size,
                 "attn_impl": rbln_attn_impl,
+                "kvcache_num_blocks": rbln_kvcache_num_blocks,
             }
         )
 
@@ -36,21 +36,28 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(
+                new_self_attn = ExaoneAttention(
+                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention,
+                    layer.attn.attention,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
 
             new_layer = ExaoneLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = ExaoneModel(
+        new_model = ExaoneModel(
+            causal_lm.transformer, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -29,20 +29,27 @@ if TYPE_CHECKING:
 
 
 class GemmaWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM", max_seq_len: int):
         new_layers = []
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(
+                new_self_attn = DecoderOnlyAttention(
+                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
+                    layer.self_attn,
+                    kvcache_partition_len=self.kvcache_partition_len,
+                    use_attention_mask=self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
             new_layer = DecoderOnlyLayer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GemmaModel(
+        new_model = GemmaModel(
+            causal_lm.model, new_layers, partition_len=self.kvcache_partition_len, max_seq_len=max_seq_len
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -32,15 +32,17 @@ if TYPE_CHECKING:
 
 
 class GPT2Wrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel", max_seq_len: int):
         if self.attn_impl != "eager":
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = GPT2Attention(
+            new_self_attn = GPT2Attention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
@@ -55,15 +55,17 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
         self.config.partial_rotary_factor = self.config.rotary_percentage
         return super().get_rotary_emb(max_seq_len=max_seq_len)
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel"):
+    def convert_to_rbln_causal_lm(self, causal_lm: "MidmLMHeadModel", max_seq_len: int):
         if self.attn_impl != "eager":
             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
         new_layers = []
         for layer in causal_lm.transformer.h:
-            new_self_attn = MidmAttention(
+            new_self_attn = MidmAttention(
+                layer.attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+            )
            new_layer = MidmLayer(layer, new_self_attn)
            new_layers.append(new_layer)
-        new_model = MidmModel(causal_lm.transformer, new_layers)
+        new_model = MidmModel(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 