optimum-rbln 0.9.3__py3-none-any.whl → 0.9.4a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +12 -4
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/modeling_base.py +12 -7
- optimum/rbln/transformers/modeling_attention_utils.py +4 -4
- optimum/rbln/transformers/modeling_outputs.py +1 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +4 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +201 -62
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +106 -36
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +43 -26
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +0 -1
- optimum/rbln/transformers/models/llava/modeling_llava.py +1 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -6
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +4 -4
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
- optimum/rbln/transformers/models/swin/modeling_swin.py +3 -3
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +9 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -2
- optimum/rbln/utils/import_utils.py +7 -1
- optimum/rbln/utils/submodule.py +3 -1
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +52 -52
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import deque
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import rebel
 import torch
@@ -24,6 +24,10 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
+if TYPE_CHECKING:
+    from transformers.configuration_utils import PreTrainedConfig
+
+
 class RBLNPageTableManager:
     EMPTY_BLOCK = -1
     NO_BLOCKS_ERROR = (
@@ -173,20 +177,23 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         page_table_manager: RBLNPageTableManager,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-
+        config: "PreTrainedConfig" = None,
+        logits_last_dim: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
         self.rbln_config = rbln_config
+        self.config = config
+        self.logits_last_dim = logits_last_dim
 
         # shared resources between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.page_table_manager = page_table_manager
+        self.out_buffers = None
 
         if self.phase == "prefill":
-            self.out_buffers = out_buffers
             self.causal_mask = 1 - torch.triu(
                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
@@ -280,28 +287,48 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
 
-
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
             raise RuntimeError(
-                f"
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        if self.rbln_config.use_local_attention:
+            local_block_tables = (
+                local_block_tables
+                if local_block_tables is not None
+                else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
             )
 
         if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(
+            for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )
 
-                if
-                    self.dec_attn_mask[b_idx]
-
+                if self.rbln_config.use_position_ids:
+                    self.dec_attn_mask[b_idx, decoding_step] = 1
+
+                    if self.batch_size < block_tables.shape[0]:
+                        block_tables = block_tables[: self.batch_size]
+
+                    if self.dec_attn_mask is not None and self.batch_size < self.dec_attn_mask.shape[0]:
+                        self.dec_attn_mask = self.dec_attn_mask[: self.batch_size]
                 else:
-
+                    if is_external_block_tables:
+                        self.dec_attn_mask[b_idx].fill_(0)
+                        self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                    else:
+                        self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
             attention_mask = self.dec_attn_mask
 
-
+        outputs = super().forward(
             inputs,
             cache_position,
             block_tables,
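The decode path above now validates its inputs up front: the incoming batch must match the compiled batch size, and `cache_position` must carry one entry per sequence. A minimal standalone sketch of that contract (the `validate_decode_inputs` helper and `compiled_batch_size` name are hypothetical, not the package's API):

```python
import torch


def validate_decode_inputs(inputs: torch.Tensor, cache_position: torch.Tensor, compiled_batch_size: int) -> None:
    # Mirrors the checks added above: the runtime graph is compiled for a fixed batch size.
    batch_size = inputs.shape[0]
    if batch_size != compiled_batch_size:
        raise RuntimeError(f"Batch size mismatch: got {batch_size}, expected {compiled_batch_size} (compiled batch size).")
    if batch_size != cache_position.shape[0]:
        raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")


# One decode step for a compiled batch of 2, each sequence at its own cache offset.
validate_decode_inputs(torch.zeros(2, 1, dtype=torch.long), torch.tensor([[5], [9]], dtype=torch.int32), compiled_batch_size=2)
```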
@@ -310,15 +337,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
             lora_int_ids if self.rbln_config.use_lora else None,
+            out=self.out_buffers,
         )
 
-
+        if self.rbln_config.output_hidden_states:
+            return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+        else:
+            return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)
 
     def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
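Per the hunk above, when `output_hidden_states` is enabled the runtime returns a flat sequence whose first element is the logits and the rest are per-layer hidden states, which the wrapper repacks. A tiny sketch of that repacking with toy shapes (the `Output` dataclass is a hypothetical stand-in for `RBLNDecoderOnlyOutput`):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class Output:  # hypothetical stand-in, not the package's class
    logits: torch.Tensor
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None


# Toy runtime output: logits followed by (num_layers + 1) hidden-state tensors.
raw = [torch.randn(1, 1, 32000)] + [torch.randn(1, 1, 4096) for _ in range(3)]
out = Output(logits=raw[0], hidden_states=tuple(raw[1:]))
```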
@@ -328,9 +360,27 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         if attention_mask is not None:
-
-
-
+            if attention_mask.dim() != 1:
+                raise ValueError("attention_mask must be a 1D tensor.")
+
+            mask_bool = attention_mask.to(dtype=torch.bool)
+            if (~mask_bool).any():
+                indice_one = torch.nonzero(mask_bool, as_tuple=False)
+                if indice_one.numel() == 0:
+                    raise ValueError("attention_mask with padding must include at least one real token.")
+                first_one_idx, last_one_idx = int(indice_one[0].item()), int(indice_one[-1].item())
+                if last_one_idx - first_one_idx + 1 != mask_bool.sum():
+                    raise ValueError(
+                        "attention_mask must group all 1s together (e.g. 000111 or 1111000). "
+                        "Zeros between real tokens like 101010 are not supported."
+                    )
+
+                if self.rbln_config.can_generate and not mask_bool[first_one_idx:].all():
+                    raise ValueError("attention_mask must be left padded for generation.")
+
+            inputs = inputs[:, mask_bool]
+            position_embed = None if position_embed is None else position_embed[:, :, :, mask_bool, :]
+            token_type_ids = None if token_type_ids is None else token_type_ids[:, mask_bool]
 
         query_length = inputs.shape[1]
         if query_length > self.rbln_config.max_seq_len:
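The added prefill validation only accepts 1-D attention masks whose ones form a single contiguous run, and requires left padding when generation is enabled. A small self-contained sketch of the same rule (the `check_mask` helper is hypothetical, not the package API):

```python
import torch


def check_mask(attention_mask: torch.Tensor, can_generate: bool = True) -> None:
    if attention_mask.dim() != 1:
        raise ValueError("attention_mask must be a 1D tensor.")
    mask_bool = attention_mask.to(dtype=torch.bool)
    ones = torch.nonzero(mask_bool, as_tuple=False)
    if ones.numel() == 0:
        raise ValueError("attention_mask must include at least one real token.")
    first, last = int(ones[0]), int(ones[-1])
    # All ones must be contiguous: the span from first to last one contains no zeros.
    if last - first + 1 != int(mask_bool.sum()):
        raise ValueError("attention_mask must group all 1s together (e.g. 000111 or 111000).")
    # For generation, padding must sit on the left, so the run of ones reaches the end.
    if can_generate and not mask_bool[first:].all():
        raise ValueError("attention_mask must be left padded for generation.")


check_mask(torch.tensor([0, 0, 1, 1, 1]))    # ok: left padded
# check_mask(torch.tensor([1, 0, 1, 0, 1]))  # raises: zeros between real tokens
```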
@@ -339,17 +389,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )
 
         # Initialize attention mask for chunked processing
-
-
-
-
-
-
-
-
-
-
-
+        if self.rbln_config.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                chunked_attention_mask = torch.zeros(
+                    1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
+                )
+            else:
+                chunked_attention_mask = torch.zeros(
+                    1,
+                    1,
+                    self.rbln_config.prefill_chunk_size,
+                    self.rbln_config.max_seq_len,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+        else:
+            chunked_attention_mask = None
 
         cache_position = (
             torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
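The rewritten initialization gives the chunked attention mask one of two layouts: a flat `(1, max_seq_len)` mask when position ids are used, or a 4-D `(1, 1, prefill_chunk_size, max_seq_len)` mask otherwise. A quick illustration with assumed sizes (not the package's config values):

```python
import torch

max_seq_len, prefill_chunk_size = 8192, 128

flat_mask = torch.zeros(1, max_seq_len)                         # use_position_ids=True
full_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len)  # use_position_ids=False

print(flat_mask.shape, full_mask.shape)  # torch.Size([1, 8192]) torch.Size([1, 1, 128, 8192])
```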
@@ -367,7 +421,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position = F.pad(cache_position, (0, padding_size))
 
         # Overwrite position_ids and padded_cache_lengths
-
+        if self.rbln_config.use_position_ids and position_ids is None:
+            position_ids = cache_position.clone()
+        else:
+            position_ids = position_ids
+
         padded_cache_lengths = 0
 
         return (
@@ -381,6 +439,68 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             token_type_ids,
         )
 
+    def _prepare_prefill_outputs(
+        self,
+        query_length: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        # Prepare out buffers
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        padded_input_length = query_length + padding_size
+        padded_mask_length = (
+            attention_mask.shape[-1] + padding_size if attention_mask is not None else padded_input_length
+        )
+        out_buffers = [[] for _ in range(padded_input_length // self.rbln_config.prefill_chunk_size)]
+
+        valid_start_index = (
+            int(torch.nonzero(attention_mask, as_tuple=False)[0][0].item()) if attention_mask is not None else 0
+        )
+
+        if self.logits_last_dim is None:
+            logits_last_dim = self.config.vocab_size if self.rbln_config.can_generate else self.config.hidden_size
+        else:
+            logits_last_dim = self.logits_last_dim
+
+        # Prepare logits buffer
+        logits_size = (
+            1,
+            1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
+            logits_last_dim,
+        )
+        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+
+        if self.rbln_config.logits_to_keep == 1:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                out_buffers[i].append(output_logits)
+        else:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].append(output_logits[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size])
+
+        # Prepare output hidden states
+        output_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            hidden_states_size = (
+                1,
+                padded_mask_length,
+                self.config.hidden_size,
+            )
+            output_hidden_states = [
+                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+                for _ in range(self.config.num_hidden_layers + 1)
+            ]
+
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].extend(
+                    [
+                        hidden_states_buffer[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size]
+                        for hidden_states_buffer in output_hidden_states
+                    ]
+                )
+
+        return out_buffers, output_logits, output_hidden_states
+
     def prefill_forward(
         self,
         inputs: torch.Tensor,
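The new `_prepare_prefill_outputs` above preallocates one logits tensor and hands each chunk a slice of it, so per-chunk forward passes write directly into the aggregated output instead of being concatenated afterwards. A minimal sketch of that buffer-slicing idea (toy sizes, not the package's signatures):

```python
import torch

# One preallocated output tensor, one slice (view) per prefill chunk.
prefill_chunk_size, padded_len, hidden = 4, 12, 8
output = torch.zeros(1, padded_len, hidden)

chunk_views = [
    output[:, i * prefill_chunk_size : (i + 1) * prefill_chunk_size]
    for i in range(padded_len // prefill_chunk_size)
]

# Writing into a chunk view fills the shared buffer in place, no concatenation needed.
chunk_views[1].copy_(torch.ones(1, prefill_chunk_size, hidden))
assert output[0, prefill_chunk_size, 0] == 1.0
```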
@@ -389,6 +509,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         is_external_block_tables: Optional[bool] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
@@ -421,9 +542,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             query_length,
             token_type_ids,
         ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+            inputs, cache_position, attention_mask, position_ids, position_embed, token_type_ids=token_type_ids
         )
 
+        out_buffers, output_logits, output_hidden_states = self._prepare_prefill_outputs(query_length, attention_mask)
+
         # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
         prefix_cached_len = cache_position[0][0].item()
         if prefix_cached_len > 0:
@@ -432,11 +555,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
                 )
             if self.rbln_config.use_attention_mask:
-
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[:, :prefix_cached_len] = 1
+                else:
+                    chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
 
         # Process input in chunks of size `prefill_chunk_size`
-
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+        for i, step in enumerate(range(0, query_length, self.rbln_config.prefill_chunk_size)):
             s, e = step, step + self.rbln_config.prefill_chunk_size
             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
             input_chunk = inputs[:, s:e]
@@ -445,17 +570,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None
 
             # Update attention mask to ensure proper causal behavior
-            if self.rbln_config.use_attention_mask
-                if
-
-
-
-
-
-
-
-
-
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    if step > 0:  # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = min(e, query_length) + prefix_cached_len
+                    if current_chunk_end > current_chunk_start:
+                        chunked_attention_mask[:, current_chunk_start:current_chunk_end] = 1
+
+                else:
+                    if step > 0:  # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, :, :, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = e + prefix_cached_len
+                    chunked_attention_mask[:, :, :, current_chunk_start:current_chunk_end] = self.causal_mask
 
             # Calculate query position if needed
             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
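For the non-position-ids path above, each iteration marks the previous chunk's columns fully visible and writes the causal triangle into the current chunk's columns. A small sketch of that sliding pattern with toy sizes (hypothetical tensors, not the package's):

```python
import torch

# Chunked causal mask over a toy sequence of 3 chunks of size 4.
chunk, seq_len = 4, 12
causal = 1 - torch.triu(torch.ones(1, 1, chunk, chunk), diagonal=1)
mask = torch.zeros(1, 1, chunk, seq_len)

for step in range(0, seq_len, chunk):
    if step > 0:
        # Everything in the previous chunk is now fully attendable.
        mask[:, :, :, step - chunk : step] = 1
    # The current chunk only attends causally within itself.
    mask[:, :, :, step : step + chunk] = causal
    # ... run the chunk's forward pass with `mask` here ...
```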
@@ -468,7 +605,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position = None
 
             # Forward pass for the current chunk
-
+            _ = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
@@ -478,31 +615,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
                 lora_int_ids if self.rbln_config.use_lora else None,
-                out=
+                out=out_buffers[i],
             )
-            output_logits.append(output_logit)
 
         # Aggregate output_logits
-
-        if self.rbln_config.logits_to_keep
-            output_logits = output_logits
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        if self.rbln_config.logits_to_keep == 1:
+            output_logits = output_logits
+        elif self.rbln_config.logits_to_keep > 1:
+            output_logits = output_logits[:, -padding_size - self.rbln_config.logits_to_keep : -padding_size, :]
         else:
-            output_logits = output_logits[:,
-            # index copy for masked output_logits
-            if attention_mask is not None:
-                new_output_logits = torch.full(
-                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
-                    fill_value=1e-10,
-                    dtype=output_logits.dtype,
-                )
-                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+            output_logits = output_logits[:, :-padding_size, :]
 
-
+        all_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            all_hidden_states = [
+                output_hidden_state[:, :-padding_size, :] for output_hidden_state in output_hidden_states
+            ]
+            all_hidden_states = tuple(all_hidden_states)
 
         # Update decoder attention mask with processed KV-cache length from prefill phase
         if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.
-
+            if self.rbln_config.use_position_ids:
+                self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+            else:
+                self.dec_attn_mask[batch_idx].fill_(0)
+                self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
-        return RBLNDecoderOnlyOutput(
+        return RBLNDecoderOnlyOutput(
+            logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+        )
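Aggregation above now just strips the right-side chunk padding from the preallocated logits (and hidden states) instead of re-scattering them with `index_copy_`. A toy illustration of the padding arithmetic (assumed sizes, not the package's values):

```python
import torch

# Trim the padding that was added to reach a multiple of prefill_chunk_size.
prefill_chunk_size, query_length, vocab = 128, 300, 10
padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size  # 84

padded_logits = torch.randn(1, query_length + padding_size, vocab)
logits = padded_logits[:, :-padding_size, :] if padding_size else padded_logits
assert logits.shape[1] == query_length
```

Note the `if padding_size` guard: when the query length is already a multiple of the chunk size, `[:-0]` would yield an empty slice.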
optimum/rbln/transformers/models/decoderonly/lora_architecture.py

@@ -142,7 +142,7 @@ class LoRALinear(nn.Module):
         padded_lora_a = []
         padded_lora_b = []
 
-        for
+        for _, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
             current_rank = lora_a.shape[0]
             if current_rank < max_rank:
                 # Pad with zeros