optimum-rbln 0.9.3__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__version__.py +2 -2
  2. optimum/rbln/configuration_utils.py +12 -4
  3. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  4. optimum/rbln/diffusers/models/controlnet.py +1 -1
  5. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
  6. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  11. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  12. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  13. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  14. optimum/rbln/modeling_base.py +12 -7
  15. optimum/rbln/transformers/modeling_attention_utils.py +4 -4
  16. optimum/rbln/transformers/modeling_outputs.py +1 -0
  17. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  18. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  19. optimum/rbln/transformers/models/colpali/modeling_colpali.py +1 -1
  20. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  21. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +4 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  23. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +201 -62
  25. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +106 -36
  27. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  28. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  29. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +43 -26
  30. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  31. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +0 -1
  32. optimum/rbln/transformers/models/llava/modeling_llava.py +1 -1
  33. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -1
  34. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  35. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  36. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  37. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -6
  38. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +4 -4
  39. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  40. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  41. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -2
  42. optimum/rbln/transformers/models/swin/modeling_swin.py +3 -3
  43. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  44. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +9 -8
  45. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -2
  46. optimum/rbln/utils/import_utils.py +7 -1
  47. optimum/rbln/utils/submodule.py +3 -1
  48. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +1 -1
  49. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +52 -52
  50. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.9.3.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import deque
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import rebel
 import torch
@@ -24,6 +24,10 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


+if TYPE_CHECKING:
+    from transformers.configuration_utils import PreTrainedConfig
+
+
 class RBLNPageTableManager:
     EMPTY_BLOCK = -1
     NO_BLOCKS_ERROR = (
@@ -173,20 +177,23 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         page_table_manager: RBLNPageTableManager,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        out_buffers: Optional[torch.Tensor] = None,
+        config: "PreTrainedConfig" = None,
+        logits_last_dim: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
         self.rbln_config = rbln_config
+        self.config = config
+        self.logits_last_dim = logits_last_dim

         # shared resources between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.page_table_manager = page_table_manager
+        self.out_buffers = None

         if self.phase == "prefill":
-            self.out_buffers = out_buffers
             self.causal_mask = 1 - torch.triu(
                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
@@ -280,28 +287,48 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")

-        if self.batch_size != cache_position.shape[0]:
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
             raise RuntimeError(
-                f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        if self.rbln_config.use_local_attention:
+            local_block_tables = (
+                local_block_tables
+                if local_block_tables is not None
+                else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
             )

         if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(self.batch_size):
+            for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )

-                if is_external_block_tables:
-                    self.dec_attn_mask[b_idx].fill_(0)
-                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                if self.rbln_config.use_position_ids:
+                    self.dec_attn_mask[b_idx, decoding_step] = 1
+
+                    if self.batch_size < block_tables.shape[0]:
+                        block_tables = block_tables[: self.batch_size]
+
+                    if self.dec_attn_mask is not None and self.batch_size < self.dec_attn_mask.shape[0]:
+                        self.dec_attn_mask = self.dec_attn_mask[: self.batch_size]
                 else:
-                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+                    if is_external_block_tables:
+                        self.dec_attn_mask[b_idx].fill_(0)
+                        self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                    else:
+                        self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

             attention_mask = self.dec_attn_mask

-        logits = super().forward(
+        outputs = super().forward(
             inputs,
             cache_position,
             block_tables,
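For reference, the decode path above now builds a trivial local block table when `use_local_attention` is enabled and the caller passes none. A standalone sketch of that fallback, with a made-up batch size chosen purely for illustration:

import torch

# Fallback used when no local_block_tables is supplied: each sequence in the
# compiled batch gets its own local block index.
batch_size = 4  # illustrative value, not a package default
local_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
print(local_block_tables)  # tensor([[0], [1], [2], [3]], dtype=torch.int16)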
@@ -310,15 +337,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
             lora_int_ids if self.rbln_config.use_lora else None,
+            out=self.out_buffers,
         )

-        return RBLNDecoderOnlyOutput(logits=logits)
+        if self.rbln_config.output_hidden_states:
+            return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+        else:
+            return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

     def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
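A small helper mirroring the return branch above, to make the output layout explicit; this is an illustrative sketch, not a function exported by optimum-rbln:

from typing import Optional, Tuple, Union

import torch

def split_runtime_outputs(
    outputs: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    output_hidden_states: bool,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
    # With output_hidden_states enabled, the compiled runtime returns the logits
    # followed by one tensor per hidden-state output; otherwise just the logits.
    if output_hidden_states:
        return outputs[0], tuple(outputs[1:])
    return outputs, None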
@@ -328,9 +360,27 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         if attention_mask is not None:
-            inputs = inputs[:, attention_mask.bool()]
-            position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
-            token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]
+            if attention_mask.dim() != 1:
+                raise ValueError("attention_mask must be a 1D tensor.")
+
+            mask_bool = attention_mask.to(dtype=torch.bool)
+            if (~mask_bool).any():
+                indice_one = torch.nonzero(mask_bool, as_tuple=False)
+                if indice_one.numel() == 0:
+                    raise ValueError("attention_mask with padding must include at least one real token.")
+                first_one_idx, last_one_idx = int(indice_one[0].item()), int(indice_one[-1].item())
+                if last_one_idx - first_one_idx + 1 != mask_bool.sum():
+                    raise ValueError(
+                        "attention_mask must group all 1s together (e.g. 000111 or 1111000). "
+                        "Zeros between real tokens like 101010 are not supported."
+                    )
+
+                if self.rbln_config.can_generate and not mask_bool[first_one_idx:].all():
+                    raise ValueError("attention_mask must be left padded for generation.")
+
+            inputs = inputs[:, mask_bool]
+            position_embed = None if position_embed is None else position_embed[:, :, :, mask_bool, :]
+            token_type_ids = None if token_type_ids is None else token_type_ids[:, mask_bool]

         query_length = inputs.shape[1]
         if query_length > self.rbln_config.max_seq_len:
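The new validation only accepts 1D masks whose ones form one contiguous run and, when generation is enabled, are right-aligned (left padding). A minimal sketch of the same check outside the runtime, useful for sanity-testing a mask before calling prefill; the helper name is illustrative and not part of the package:

import torch

def check_prefill_attention_mask(attention_mask: torch.Tensor, can_generate: bool = True) -> None:
    # Illustrative re-implementation of the checks added above.
    if attention_mask.dim() != 1:
        raise ValueError("attention_mask must be a 1D tensor.")
    mask_bool = attention_mask.to(dtype=torch.bool)
    if (~mask_bool).any():
        ones = torch.nonzero(mask_bool, as_tuple=False)
        if ones.numel() == 0:
            raise ValueError("attention_mask with padding must include at least one real token.")
        first, last = int(ones[0].item()), int(ones[-1].item())
        if last - first + 1 != int(mask_bool.sum()):
            raise ValueError("attention_mask must group all 1s together (e.g. 000111).")
        if can_generate and not bool(mask_bool[first:].all()):
            raise ValueError("attention_mask must be left padded for generation.")

check_prefill_attention_mask(torch.tensor([0, 0, 1, 1, 1]))  # OK: left padded
# check_prefill_attention_mask(torch.tensor([1, 0, 1]))      # raises: 1s not contiguous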
@@ -339,17 +389,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )

         # Initialize attention mask for chunked processing
-        chunked_attention_mask = (
-            torch.zeros(
-                1,
-                1,
-                self.rbln_config.prefill_chunk_size,
-                self.rbln_config.max_seq_len,
-                dtype=self.rbln_config.torch_dtype,
-            )
-            if self.rbln_config.use_attention_mask
-            else None
-        )
+        if self.rbln_config.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                chunked_attention_mask = torch.zeros(
+                    1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
+                )
+            else:
+                chunked_attention_mask = torch.zeros(
+                    1,
+                    1,
+                    self.rbln_config.prefill_chunk_size,
+                    self.rbln_config.max_seq_len,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+        else:
+            chunked_attention_mask = None

         cache_position = (
             torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
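To make the two mask layouts above concrete, a small sketch with toy sizes (the numbers are illustrative only): with position ids the prefill mask is a flat per-position mask over the whole sequence, otherwise it is the 4D per-chunk causal mask the compiled graph expects.

import torch

max_seq_len, prefill_chunk_size = 8, 4  # toy sizes, for illustration only

# use_position_ids=True: one slot per cache position
flat_mask = torch.zeros(1, max_seq_len, dtype=torch.float32)

# use_position_ids=False: per-chunk causal mask against the full sequence
causal_chunk_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len, dtype=torch.float32)

print(flat_mask.shape)          # torch.Size([1, 8])
print(causal_chunk_mask.shape)  # torch.Size([1, 1, 4, 8])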
@@ -367,7 +421,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position = F.pad(cache_position, (0, padding_size))

         # Overwrite position_ids and padded_cache_lengths
-        position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
+        if self.rbln_config.use_position_ids and position_ids is None:
+            position_ids = cache_position.clone()
+        else:
+            position_ids = position_ids
+
         padded_cache_lengths = 0

         return (
@@ -381,6 +439,68 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             token_type_ids,
         )

+    def _prepare_prefill_outputs(
+        self,
+        query_length: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        # Prepare out buffers
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        padded_input_length = query_length + padding_size
+        padded_mask_length = (
+            attention_mask.shape[-1] + padding_size if attention_mask is not None else padded_input_length
+        )
+        out_buffers = [[] for _ in range(padded_input_length // self.rbln_config.prefill_chunk_size)]
+
+        valid_start_index = (
+            int(torch.nonzero(attention_mask, as_tuple=False)[0][0].item()) if attention_mask is not None else 0
+        )
+
+        if self.logits_last_dim is None:
+            logits_last_dim = self.config.vocab_size if self.rbln_config.can_generate else self.config.hidden_size
+        else:
+            logits_last_dim = self.logits_last_dim
+
+        # Prepare logits buffer
+        logits_size = (
+            1,
+            1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
+            logits_last_dim,
+        )
+        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+
+        if self.rbln_config.logits_to_keep == 1:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                out_buffers[i].append(output_logits)
+        else:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].append(output_logits[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size])
+
+        # Prepare output hidden states
+        output_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            hidden_states_size = (
+                1,
+                padded_mask_length,
+                self.config.hidden_size,
+            )
+            output_hidden_states = [
+                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+                for _ in range(self.config.num_hidden_layers + 1)
+            ]
+
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].extend(
+                    [
+                        hidden_states_buffer[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size]
+                        for hidden_states_buffer in output_hidden_states
+                    ]
+                )
+
+        return out_buffers, output_logits, output_hidden_states
+
     def prefill_forward(
         self,
         inputs: torch.Tensor,
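A worked example of the buffer arithmetic behind `_prepare_prefill_outputs`, using toy numbers that are illustrative rather than package defaults: a 10-token prompt with a chunk size of 4 is padded to 12 tokens, processed as 3 chunks, and each chunk writes into its own view of one shared output buffer.

import torch

prefill_chunk_size, query_length, hidden = 4, 10, 16  # toy values

padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size   # 2
padded_input_length = query_length + padding_size                         # 12
num_chunks = padded_input_length // prefill_chunk_size                    # 3

# One shared buffer; each chunk's forward pass writes into its own slice of it.
output = torch.full((1, padded_input_length, hidden), fill_value=1e-10)
chunk_views = [output[:, i * prefill_chunk_size : (i + 1) * prefill_chunk_size] for i in range(num_chunks)]

print(padding_size, padded_input_length, num_chunks)  # 2 12 3
print(chunk_views[0].shape)                           # torch.Size([1, 4, 16])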
@@ -389,6 +509,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         is_external_block_tables: Optional[bool] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
@@ -421,9 +542,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             query_length,
             token_type_ids,
         ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+            inputs, cache_position, attention_mask, position_ids, position_embed, token_type_ids=token_type_ids
         )

+        out_buffers, output_logits, output_hidden_states = self._prepare_prefill_outputs(query_length, attention_mask)
+
         # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
         prefix_cached_len = cache_position[0][0].item()
         if prefix_cached_len > 0:
@@ -432,11 +555,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
                 )
             if self.rbln_config.use_attention_mask:
-                chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[:, :prefix_cached_len] = 1
+                else:
+                    chunked_attention_mask[:, :, :, :prefix_cached_len] = 1

         # Process input in chunks of size `prefill_chunk_size`
-        output_logits = []
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+        for i, step in enumerate(range(0, query_length, self.rbln_config.prefill_chunk_size)):
             s, e = step, step + self.rbln_config.prefill_chunk_size
             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
             input_chunk = inputs[:, s:e]
@@ -445,17 +570,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None

             # Update attention mask to ensure proper causal behavior
-            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                if step > 0:  # update previous chunk
-                    chunked_attention_mask[
-                        :,
-                        :,
-                        :,
-                        s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
-                        - self.rbln_config.prefill_chunk_size
-                        + prefix_cached_len,
-                    ] = 1
-                chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    if step > 0:  # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = min(e, query_length) + prefix_cached_len
+                    if current_chunk_end > current_chunk_start:
+                        chunked_attention_mask[:, current_chunk_start:current_chunk_end] = 1
+
+                else:
+                    if step > 0:  # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, :, :, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = e + prefix_cached_len
+                    chunked_attention_mask[:, :, :, current_chunk_start:current_chunk_end] = self.causal_mask

             # Calculate query position if needed
             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
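A minimal sketch of the index arithmetic used above for the 2D (position-id) mask, with toy numbers and no prefix caching; each iteration first finalizes the previous chunk and then opens the current one up to the real query length.

import torch

prefill_chunk_size, query_length, max_seq_len = 4, 10, 16  # toy values
mask = torch.zeros(1, max_seq_len)

for step in range(0, query_length, prefill_chunk_size):
    s, e = step, step + prefill_chunk_size
    if step > 0:
        # Previous chunk is now fully processed: mark it attendable.
        mask[:, s - prefill_chunk_size : s] = 1
    # Open the current chunk only up to the real query length (padding stays 0).
    mask[:, s : min(e, query_length)] = 1

print(mask[0, :query_length])  # all ones for the 10 real tokens
print(mask[0, query_length:])  # zeros for the padded tail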
@@ -468,7 +605,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position = None

             # Forward pass for the current chunk
-            output_logit = super().forward(
+            _ = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
@@ -478,31 +615,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
                 lora_int_ids if self.rbln_config.use_lora else None,
-                out=self.out_buffers,
+                out=out_buffers[i],
             )
-            output_logits.append(output_logit)

         # Aggregate output_logits
-        output_logits = torch.concat(output_logits, dim=-2)
-        if self.rbln_config.logits_to_keep > 0:
-            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        if self.rbln_config.logits_to_keep == 1:
+            output_logits = output_logits
+        elif self.rbln_config.logits_to_keep > 1:
+            output_logits = output_logits[:, -padding_size - self.rbln_config.logits_to_keep : -padding_size, :]
         else:
-            output_logits = output_logits[:, :query_length, :]
-            # index copy for masked output_logits
-            if attention_mask is not None:
-                new_output_logits = torch.full(
-                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
-                    fill_value=1e-10,
-                    dtype=output_logits.dtype,
-                )
-                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+            output_logits = output_logits[:, :-padding_size, :]

-                output_logits = new_output_logits
+        all_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            all_hidden_states = [
+                output_hidden_state[:, :-padding_size, :] for output_hidden_state in output_hidden_states
+            ]
+            all_hidden_states = tuple(all_hidden_states)

         # Update decoder attention mask with processed KV-cache length from prefill phase
         if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.dec_attn_mask[batch_idx].fill_(0)
-            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+            if self.rbln_config.use_position_ids:
+                self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+            else:
+                self.dec_attn_mask[batch_idx].fill_(0)
+                self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
+        return RBLNDecoderOnlyOutput(
+            logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+        )
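Since every chunk now writes directly into the preallocated buffer, the final aggregation is just a slice. A toy example of the index arithmetic for the `logits_to_keep > 1` and default cases (values are illustrative, and the slicing shown assumes the prompt length is not an exact multiple of the chunk size, so `padding_size > 0`):

import torch

prefill_chunk_size, query_length, logits_to_keep = 4, 10, 3  # toy values
padded_mask_length, vocab = 12, 7

padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size  # 2
output_logits = torch.randn(1, padded_mask_length, vocab)

kept = output_logits[:, -padding_size - logits_to_keep : -padding_size, :]  # last 3 real positions
full = output_logits[:, :-padding_size, :]                                  # all 10 real positions

print(kept.shape)  # torch.Size([1, 3, 7])
print(full.shape)  # torch.Size([1, 10, 7])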
optimum/rbln/transformers/models/decoderonly/lora_architecture.py
@@ -142,7 +142,7 @@ class LoRALinear(nn.Module):
        padded_lora_a = []
        padded_lora_b = []

-        for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+        for _, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
            current_rank = lora_a.shape[0]
            if current_rank < max_rank:
                # Pad with zeros
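For context, the surrounding loop pads per-adapter LoRA weights up to a shared maximum rank. A hedged sketch of that zero-padding step with made-up shapes (the actual layout of `lora_a` and `lora_b` in the package may differ):

import torch
import torch.nn.functional as F

max_rank = 8
# Assumed shapes for illustration: lora_a is (rank, in_features), lora_b is (out_features, rank).
lora_a = torch.randn(4, 32)
lora_b = torch.randn(64, 4)

pad = max_rank - lora_a.shape[0]
padded_a = F.pad(lora_a, (0, 0, 0, pad))  # pad rank rows -> (8, 32)
padded_b = F.pad(lora_b, (0, pad))        # pad rank cols -> (64, 8)

print(padded_a.shape, padded_b.shape)  # torch.Size([8, 32]) torch.Size([64, 8])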