optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. optimum/rbln/__init__.py +12 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -6
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +12 -8
  38. optimum/rbln/transformers/configuration_generic.py +0 -27
  39. optimum/rbln/transformers/modeling_attention_utils.py +242 -109
  40. optimum/rbln/transformers/modeling_generic.py +2 -61
  41. optimum/rbln/transformers/modeling_outputs.py +1 -0
  42. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  43. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  44. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  45. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  46. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  47. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  48. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  49. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  50. optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
  51. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
  52. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
  53. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  54. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
  55. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
  56. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  57. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
  59. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  60. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  61. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  62. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
  63. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  64. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  67. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  68. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
  69. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  70. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  71. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  72. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  73. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  74. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  75. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  76. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
  77. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
  78. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
  79. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  80. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  81. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  82. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  83. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  84. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  85. optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
  86. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  87. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  88. optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
  89. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  90. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  91. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  92. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  93. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  94. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  95. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  96. optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
  97. optimum/rbln/utils/deprecation.py +213 -0
  98. optimum/rbln/utils/hub.py +14 -3
  99. optimum/rbln/utils/import_utils.py +7 -1
  100. optimum/rbln/utils/runtime_utils.py +32 -0
  101. optimum/rbln/utils/submodule.py +3 -1
  102. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
  103. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
  104. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
  105. optimum/rbln/utils/depreacate_utils.py +0 -16
  106. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
  107. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from collections import deque
- from typing import Any, Optional
+ from typing import TYPE_CHECKING, Any, Optional

  import rebel
  import torch
@@ -24,6 +24,10 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
  from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


+ if TYPE_CHECKING:
+     from transformers.configuration_utils import PreTrainedConfig
+
+
  class RBLNPageTableManager:
      EMPTY_BLOCK = -1
      NO_BLOCKS_ERROR = (
@@ -46,6 +50,12 @@ class RBLNPageTableManager:
          """
          If the block is empty (empty_block), allocates a block from the free_block_pool.
          """
+         if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+             raise IndexError(
+                 f"Invalid index(batch_idx={batch_idx}, block_idx={block_idx}): \n \
+                 BlockTable Shape(batch_axis, block_axis): {self.block_tables.shape}, BlockSize: {self.rbln_config.kvcache_block_size}"
+             )
+
          if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
              if self.free_block_pool:
                  block = self.free_block_pool.popleft()
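The hunk above moves the index bounds check into update_block itself (the next hunk removes it from the caller), so an out-of-range batch or block index now fails before any allocation happens. Below is a minimal, standalone sketch of the same deque-backed block-allocation pattern; it is illustrative only and not the library's actual RBLNPageTableManager.

from collections import deque

import torch


class TinyPageTable:
    EMPTY_BLOCK = -1

    def __init__(self, batch_size: int, blocks_per_seq: int, total_blocks: int):
        self.block_tables = torch.full((batch_size, blocks_per_seq), self.EMPTY_BLOCK, dtype=torch.int16)
        self.free_block_pool = deque(range(total_blocks))

    def update_block(self, batch_idx: int, block_idx: int) -> None:
        # Validate indices up front so a bad caller fails loudly instead of corrupting the table.
        if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
            raise IndexError(
                f"Invalid index (batch_idx={batch_idx}, block_idx={block_idx}); "
                f"block table shape is {tuple(self.block_tables.shape)}"
            )
        # Only allocate when the slot is still empty; repeated calls are a no-op.
        if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
            if not self.free_block_pool:
                raise RuntimeError("No free KV-cache blocks left.")
            self.block_tables[batch_idx][block_idx] = self.free_block_pool.popleft()


table = TinyPageTable(batch_size=2, blocks_per_seq=4, total_blocks=8)
table.update_block(0, 0)     # allocates the first free block for sequence 0
# table.update_block(0, 99)  # would raise IndexError, mirroring the check added above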
@@ -96,8 +106,6 @@ class RBLNPageTableManager:
          s, e = cache_position[0][0].item(), cache_position[0][-1].item()
          for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
              block_idx = position // self.rbln_config.kvcache_block_size
-             if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                 raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
              self.update_block(batch_idx, block_idx)

          return self.replace_empty_block(self.block_tables[batch_idx])
@@ -169,20 +177,23 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          dec_attn_mask: torch.Tensor,
          page_table_manager: RBLNPageTableManager,
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-         out_buffers: Optional[torch.Tensor] = None,
+         config: "PreTrainedConfig" = None,
+         logits_last_dim: Optional[int] = None,
          **kwargs: Any,
      ) -> None:
          super().__init__(runtime, **kwargs)
          self.phase = phase
          self.batch_size = batch_size
          self.rbln_config = rbln_config
+         self.config = config
+         self.logits_last_dim = logits_last_dim

          # shared resources between prefill and decode phase
          self.dec_attn_mask = dec_attn_mask
          self.page_table_manager = page_table_manager
+         self.out_buffers = None

          if self.phase == "prefill":
-             self.out_buffers = out_buffers
              self.causal_mask = 1 - torch.triu(
                  torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
              )
@@ -276,28 +287,48 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
              raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")

-         if self.batch_size != cache_position.shape[0]:
+         batch_size = inputs.shape[0]
+         if batch_size != self.batch_size:
              raise RuntimeError(
-                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+             )
+
+         if batch_size != cache_position.shape[0]:
+             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+         if self.rbln_config.use_local_attention:
+             local_block_tables = (
+                 local_block_tables
+                 if local_block_tables is not None
+                 else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
              )

          if self.rbln_config.use_attention_mask and attention_mask is None:
-             for b_idx in range(self.batch_size):
+             for b_idx in range(batch_size):
                  decoding_step = cache_position[b_idx].item()
                  if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                      raise ValueError(
                          f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                      )

-                 if is_external_block_tables:
-                     self.dec_attn_mask[b_idx].fill_(0)
-                     self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                 if self.rbln_config.use_position_ids:
+                     self.dec_attn_mask[b_idx, decoding_step] = 1
+
+                     if self.batch_size < block_tables.shape[0]:
+                         block_tables = block_tables[: self.batch_size]
+
+                     if self.dec_attn_mask is not None and self.batch_size < self.dec_attn_mask.shape[0]:
+                         self.dec_attn_mask = self.dec_attn_mask[: self.batch_size]
                  else:
-                     self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+                     if is_external_block_tables:
+                         self.dec_attn_mask[b_idx].fill_(0)
+                         self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                     else:
+                         self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

              attention_mask = self.dec_attn_mask

-         logits = super().forward(
+         outputs = super().forward(
              inputs,
              cache_position,
              block_tables,
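The decode path above now maintains two attention-mask layouts: a 2D [batch, seq_len] mask when position_ids are in use, and the original broadcastable 4D [batch, 1, 1, seq_len] mask otherwise. The toy sketch below only illustrates the two layouts and the per-step update; shapes and the update rule are simplified and this is not the runtime code.

import torch

batch_size, max_seq_len = 2, 16
mask_2d = torch.zeros(batch_size, max_seq_len)        # layout used together with position_ids
mask_4d = torch.zeros(batch_size, 1, 1, max_seq_len)  # original broadcastable layout

cache_position = torch.tensor([[5], [9]])             # current decoding step per sequence
for b_idx in range(batch_size):
    step = cache_position[b_idx].item()
    mask_2d[b_idx, step] = 1                           # mark only the newly decoded position
    mask_4d[b_idx, :, :, : step + 1] = 1               # everything up to the current step is visible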
@@ -306,15 +337,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              attention_mask if self.rbln_config.use_attention_mask else None,
              position_ids if self.rbln_config.use_position_ids else None,
              lora_int_ids if self.rbln_config.use_lora else None,
+             out=self.out_buffers,
          )

-         return RBLNDecoderOnlyOutput(logits=logits)
+         if self.rbln_config.output_hidden_states:
+             return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+         else:
+             return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

      def _prepare_prefill_inputs(
          self,
          inputs: torch.Tensor,
          cache_position: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
          position_embed: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
      ):
@@ -324,9 +360,27 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          # Handle continuous batching in a compiled graph by extracting valid inputs
          # If an attention mask is provided, select only the valid (non-masked) inputs
          if attention_mask is not None:
-             inputs = inputs[:, attention_mask.bool()]
-             position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
-             token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]
+             if attention_mask.dim() != 1:
+                 raise ValueError("attention_mask must be a 1D tensor.")
+
+             mask_bool = attention_mask.to(dtype=torch.bool)
+             if (~mask_bool).any():
+                 indice_one = torch.nonzero(mask_bool, as_tuple=False)
+                 if indice_one.numel() == 0:
+                     raise ValueError("attention_mask with padding must include at least one real token.")
+                 first_one_idx, last_one_idx = int(indice_one[0].item()), int(indice_one[-1].item())
+                 if last_one_idx - first_one_idx + 1 != mask_bool.sum():
+                     raise ValueError(
+                         "attention_mask must group all 1s together (e.g. 000111 or 1111000). "
+                         "Zeros between real tokens like 101010 are not supported."
+                     )
+
+                 if self.rbln_config.can_generate and not mask_bool[first_one_idx:].all():
+                     raise ValueError("attention_mask must be left padded for generation.")
+
+             inputs = inputs[:, mask_bool]
+             position_embed = None if position_embed is None else position_embed[:, :, :, mask_bool, :]
+             token_type_ids = None if token_type_ids is None else token_type_ids[:, mask_bool]

          query_length = inputs.shape[1]
          if query_length > self.rbln_config.max_seq_len:
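The validation added above accepts only masks whose real tokens form a single contiguous run and, when the model can generate, requires that run to be left padded. A standalone sketch of that check is shown below; it is written as a free function for illustration and is not the library's method.

import torch


def validate_prefill_mask(attention_mask: torch.Tensor, can_generate: bool = True) -> None:
    if attention_mask.dim() != 1:
        raise ValueError("attention_mask must be a 1D tensor.")
    mask_bool = attention_mask.to(dtype=torch.bool)
    ones = torch.nonzero(mask_bool, as_tuple=False)
    if ones.numel() == 0:
        raise ValueError("attention_mask must include at least one real token.")
    first, last = int(ones[0]), int(ones[-1])
    # All 1s must be contiguous: the span from first to last 1 must contain only 1s.
    if last - first + 1 != int(mask_bool.sum()):
        raise ValueError("attention_mask must keep all 1s contiguous (e.g. 000111), not interleaved (101010).")
    # For generation, the run of 1s must reach the end of the mask, i.e. only left padding.
    if can_generate and not mask_bool[first:].all():
        raise ValueError("attention_mask must be left padded for generation.")


validate_prefill_mask(torch.tensor([0, 0, 1, 1, 1]))    # ok: left padded
# validate_prefill_mask(torch.tensor([1, 0, 1, 1]))     # raises: 1s are not contiguous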
@@ -335,17 +389,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              )

          # Initialize attention mask for chunked processing
-         chunked_attention_mask = (
-             torch.zeros(
-                 1,
-                 1,
-                 self.rbln_config.prefill_chunk_size,
-                 self.rbln_config.max_seq_len,
-                 dtype=self.rbln_config.torch_dtype,
-             )
-             if self.rbln_config.use_attention_mask
-             else None
-         )
+         if self.rbln_config.use_attention_mask:
+             if self.rbln_config.use_position_ids:
+                 chunked_attention_mask = torch.zeros(
+                     1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
+                 )
+             else:
+                 chunked_attention_mask = torch.zeros(
+                     1,
+                     1,
+                     self.rbln_config.prefill_chunk_size,
+                     self.rbln_config.max_seq_len,
+                     dtype=self.rbln_config.torch_dtype,
+                 )
+         else:
+             chunked_attention_mask = None

          cache_position = (
              torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
@@ -363,7 +421,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              cache_position = F.pad(cache_position, (0, padding_size))

          # Overwrite position_ids and padded_cache_lengths
-         position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
+         if self.rbln_config.use_position_ids and position_ids is None:
+             position_ids = cache_position.clone()
+         else:
+             position_ids = position_ids
+

          padded_cache_lengths = 0
          return (
@@ -377,6 +439,68 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              token_type_ids,
          )

+     def _prepare_prefill_outputs(
+         self,
+         query_length: int,
+         attention_mask: Optional[torch.Tensor] = None,
+     ):
+         # Prepare out buffers
+         padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+         padded_input_length = query_length + padding_size
+         padded_mask_length = (
+             attention_mask.shape[-1] + padding_size if attention_mask is not None else padded_input_length
+         )
+         out_buffers = [[] for _ in range(padded_input_length // self.rbln_config.prefill_chunk_size)]
+
+         valid_start_index = (
+             int(torch.nonzero(attention_mask, as_tuple=False)[0][0].item()) if attention_mask is not None else 0
+         )
+
+         if self.logits_last_dim is None:
+             logits_last_dim = self.config.vocab_size if self.rbln_config.can_generate else self.config.hidden_size
+         else:
+             logits_last_dim = self.logits_last_dim
+
+         # Prepare logits buffer
+         logits_size = (
+             1,
+             1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
+             logits_last_dim,
+         )
+         output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+
+         if self.rbln_config.logits_to_keep == 1:
+             for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                 out_buffers[i].append(output_logits)
+         else:
+             for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                 s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                 out_buffers[i].append(output_logits[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size])
+
+         # Prepare output hidden states
+         output_hidden_states = None
+         if self.rbln_config.output_hidden_states:
+             hidden_states_size = (
+                 1,
+                 padded_mask_length,
+                 self.config.hidden_size,
+             )
+             output_hidden_states = [
+                 torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+                 for _ in range(self.config.num_hidden_layers + 1)
+             ]
+
+             for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                 s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                 out_buffers[i].extend(
+                     [
+                         hidden_states_buffer[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size]
+                         for hidden_states_buffer in output_hidden_states
+                     ]
+                 )
+
+         return out_buffers, output_logits, output_hidden_states
+
      def prefill_forward(
          self,
          inputs: torch.Tensor,
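The new _prepare_prefill_outputs above pre-allocates one full-length logits buffer (plus optional hidden-state buffers) and hands each prefill chunk a view into it, so the runtime writes results in place through out= and no torch.concat is needed afterwards. The toy sketch below shows the aliasing idea with a plain matmul standing in for the compiled runtime call; names and sizes are illustrative only.

import torch

chunk, seq_len, vocab = 4, 12, 10
hidden = torch.randn(1, seq_len, 8)
lm_head = torch.randn(8, vocab)

# One buffer for the whole (padded) sequence, sliced into per-chunk views.
output_logits = torch.full((1, seq_len, vocab), 1e-10)
out_buffers = [output_logits[:, s : s + chunk] for s in range(0, seq_len, chunk)]

for i, s in enumerate(range(0, seq_len, chunk)):
    # Stand-in for the compiled forward: it writes its result directly into `out=`.
    torch.matmul(hidden[:, s : s + chunk], lm_head, out=out_buffers[i])

# The slices alias output_logits, so the aggregate buffer is already fully populated.
assert torch.allclose(output_logits, hidden @ lm_head)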
@@ -385,6 +509,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          batch_idx: Optional[int] = None,
          block_tables: Optional[torch.Tensor] = None,
          is_external_block_tables: Optional[bool] = None,
+         position_ids: Optional[torch.Tensor] = None,
          position_embed: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          local_block_tables: Optional[torch.Tensor] = None,
@@ -417,9 +542,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              query_length,
              token_type_ids,
          ) = self._prepare_prefill_inputs(
-             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+             inputs, cache_position, attention_mask, position_ids, position_embed, token_type_ids=token_type_ids
          )

+         out_buffers, output_logits, output_hidden_states = self._prepare_prefill_outputs(query_length, attention_mask)
+
          # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
          prefix_cached_len = cache_position[0][0].item()
          if prefix_cached_len > 0:
@@ -428,11 +555,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                      "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
                  )
              if self.rbln_config.use_attention_mask:
-                 chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
+                 if self.rbln_config.use_position_ids:
+                     chunked_attention_mask[:, :prefix_cached_len] = 1
+                 else:
+                     chunked_attention_mask[:, :, :, :prefix_cached_len] = 1

          # Process input in chunks of size `prefill_chunk_size`
-         output_logits = []
-         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+         for i, step in enumerate(range(0, query_length, self.rbln_config.prefill_chunk_size)):
              s, e = step, step + self.rbln_config.prefill_chunk_size
              # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
              input_chunk = inputs[:, s:e]
@@ -441,17 +570,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None

              # Update attention mask to ensure proper causal behavior
-             if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                 if step > 0: # update previous chunk
-                     chunked_attention_mask[
-                         :,
-                         :,
-                         :,
-                         s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
-                         - self.rbln_config.prefill_chunk_size
-                         + prefix_cached_len,
-                     ] = 1
-                 chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask
+             if self.rbln_config.use_attention_mask:
+                 if self.rbln_config.use_position_ids:
+                     if step > 0: # update previous chunk
+                         # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                         prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                         prev_chunk_end = s + prefix_cached_len
+                         chunked_attention_mask[:, prev_chunk_start:prev_chunk_end] = 1
+
+                     current_chunk_start = s + prefix_cached_len
+                     current_chunk_end = min(e, query_length) + prefix_cached_len
+                     if current_chunk_end > current_chunk_start:
+                         chunked_attention_mask[:, current_chunk_start:current_chunk_end] = 1
+
+                 else:
+                     if step > 0: # update previous chunk
+                         # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                         prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                         prev_chunk_end = s + prefix_cached_len
+                         chunked_attention_mask[:, :, :, prev_chunk_start:prev_chunk_end] = 1
+
+                     current_chunk_start = s + prefix_cached_len
+                     current_chunk_end = e + prefix_cached_len
+                     chunked_attention_mask[:, :, :, current_chunk_start:current_chunk_end] = self.causal_mask

              # Calculate query position if needed
              if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
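In the 4D branch above, each chunk iteration first marks the previous chunk as fully visible and then stamps the causal triangle onto the current chunk's columns. The snippet below is a compact, runtime-independent illustration of that update rule (no prefix caching, toy sizes); it is not the library's code.

import torch

chunk, max_seq_len = 4, 12
# Lower-triangular causal pattern for one chunk, matching the causal_mask built in __init__.
causal_mask = 1 - torch.triu(torch.ones(1, 1, chunk, chunk), diagonal=1)
mask = torch.zeros(1, 1, chunk, max_seq_len)

for step in range(0, max_seq_len, chunk):
    s, e = step, step + chunk
    if step > 0:
        mask[:, :, :, s - chunk : s] = 1   # positions from the previous chunk become fully visible
    mask[:, :, :, s:e] = causal_mask       # causal pattern within the current chunk
    # ... the compiled prefill forward for this chunk would run here with `mask` ...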
@@ -464,7 +605,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                  query_position = None

              # Forward pass for the current chunk
-             output_logit = super().forward(
+             _ = super().forward(
                  input_chunk,
                  cache_pos_chunk,
                  block_tables,
@@ -474,31 +615,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                  chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                  position_ids_chunk,
                  lora_int_ids if self.rbln_config.use_lora else None,
-                 out=self.out_buffers,
+                 out=out_buffers[i],
              )
-             output_logits.append(output_logit)

          # Aggregate output_logits
-         output_logits = torch.concat(output_logits, dim=-2)
-         if self.rbln_config.logits_to_keep > 0:
-             output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+         padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+         if self.rbln_config.logits_to_keep == 1:
+             output_logits = output_logits
+         elif self.rbln_config.logits_to_keep > 1:
+             output_logits = output_logits[:, -padding_size - self.rbln_config.logits_to_keep : -padding_size, :]
          else:
-             output_logits = output_logits[:, :query_length, :]
-             # index copy for masked output_logits
-             if attention_mask is not None:
-                 new_output_logits = torch.full(
-                     (1, attention_mask.shape[-1], output_logits.shape[-1]),
-                     fill_value=1e-10,
-                     dtype=output_logits.dtype,
-                 )
-                 mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-                 new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+             output_logits = output_logits[:, :-padding_size, :]

-                 output_logits = new_output_logits
+         all_hidden_states = None
+         if self.rbln_config.output_hidden_states:
+             all_hidden_states = [
+                 output_hidden_state[:, :-padding_size, :] for output_hidden_state in output_hidden_states
+             ]
+             all_hidden_states = tuple(all_hidden_states)

          # Update decoder attention mask with processed KV-cache length from prefill phase
          if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
-             self.dec_attn_mask[batch_idx].fill_(0)
-             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+             if self.rbln_config.use_position_ids:
+                 self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+             else:
+                 self.dec_attn_mask[batch_idx].fill_(0)
+                 self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-         return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
+         return RBLNDecoderOnlyOutput(
+             logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+         )
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py
@@ -12,10 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Any, Dict, Optional
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union

  import torch
+ from transformers import GenerationConfig
  from transformers.generation.utils import GenerationMixin
+ from transformers.modeling_outputs import ModelOutput


  if TYPE_CHECKING:
@@ -91,20 +93,26 @@ class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
          self,
          input_ids: torch.LongTensor,
          attention_mask: Optional[torch.LongTensor] = None,
-         max_length: Optional[int] = None,
+         generation_config: Optional[GenerationConfig] = None,
          **kwargs,
-     ):
+     ) -> Union[ModelOutput, torch.LongTensor]:
          """
          The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+         Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.

          Args:
-             input_ids: The input ids to the model.
-             attention_mask: The attention mask to the model.
-             max_length: The maximum length of the generated text.
-             kwargs: Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+             input_ids (torch.LongTensor): The input ids to the model.
+             attention_mask (torch.LongTensor, optional): The attention mask to the model.
+             generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                 If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                 Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+             kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+         Returns:
+             A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.
          """
-         if max_length is not None:
-             kwargs["max_length"] = max_length
+         if generation_config is not None:
+             kwargs["generation_config"] = generation_config
          if attention_mask is not None:
              kwargs["attention_mask"] = attention_mask

optimum/rbln/transformers/models/decoderonly/lora_architecture.py
@@ -142,7 +142,7 @@ class LoRALinear(nn.Module):
          padded_lora_a = []
          padded_lora_b = []

-         for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+         for _, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
              current_rank = lora_a.shape[0]
              if current_rank < max_rank:
                  # Pad with zeros
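A usage sketch for the generate() signature change in the generation_decoderonly.py hunk above: callers now pass a GenerationConfig instead of a bare max_length. The model class and loading arguments below are illustrative assumptions (RBLNLlamaForCausalLM and export=True follow optimum-rbln's usual conventions; check the documentation for your model), while the generate() call itself is the standard transformers GenerationMixin API that this mixin wraps.

from transformers import AutoTokenizer, GenerationConfig

from optimum.rbln import RBLNLlamaForCausalLM  # assumed class; pick the RBLN class matching your checkpoint

model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True compiles the model for RBLN NPUs at load time (assumed loading convention).
model = RBLNLlamaForCausalLM.from_pretrained(model_id, export=True)

# Previously: model.generate(**inputs, max_length=...). Now a GenerationConfig is passed through.
gen_config = GenerationConfig(max_new_tokens=64, do_sample=False)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, generation_config=gen_config)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))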