optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,7 @@
  # limitations under the License.
 
  from collections import deque
- from typing import Any, Optional
+ from typing import TYPE_CHECKING, Any, Optional
 
  import rebel
  import torch
@@ -24,6 +24,10 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
  from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
+ if TYPE_CHECKING:
+ from transformers.configuration_utils import PreTrainedConfig
+
+
  class RBLNPageTableManager:
  EMPTY_BLOCK = -1
  NO_BLOCKS_ERROR = (
@@ -46,6 +50,12 @@ class RBLNPageTableManager:
  """
  If the block is empty (empty_block), allocates a block from the free_block_pool.
  """
+ if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+ raise IndexError(
+ f"Invalid index(batch_idx={batch_idx}, block_idx={block_idx}): \n \
+ BlockTable Shape(batch_axis, block_axis): {self.block_tables.shape}, BlockSize: {self.rbln_config.kvcache_block_size}"
+ )
+
  if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
  if self.free_block_pool:
  block = self.free_block_pool.popleft()
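For context, the structure this new bounds check guards is a per-sequence table of KV-cache block indices backed by a pool of free blocks. A minimal standalone sketch (not taken from the package) with hypothetical sizes:

    from collections import deque

    import torch

    EMPTY_BLOCK = -1
    batch_size, blocks_per_seq, total_blocks = 2, 4, 8  # hypothetical sizes

    # One row of block indices per sequence; -1 marks "no block assigned yet".
    block_tables = torch.full((batch_size, blocks_per_seq), EMPTY_BLOCK, dtype=torch.int16)
    free_block_pool = deque(range(total_blocks))

    def update_block(batch_idx: int, block_idx: int) -> None:
        # Same idea as the added guard: reject out-of-range coordinates up front.
        if batch_idx >= block_tables.shape[0] or block_idx >= block_tables.shape[1]:
            raise IndexError(f"Invalid index(batch_idx={batch_idx}, block_idx={block_idx})")
        # Allocate a block from the free pool only if the slot is still empty.
        if block_tables[batch_idx, block_idx] == EMPTY_BLOCK:
            block_tables[batch_idx, block_idx] = free_block_pool.popleft()

    update_block(0, 0)  # assigns physical block 0 to sequence 0
    update_block(1, 0)  # assigns physical block 1 to sequence 1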
@@ -96,8 +106,6 @@ class RBLNPageTableManager:
  s, e = cache_position[0][0].item(), cache_position[0][-1].item()
  for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
  block_idx = position // self.rbln_config.kvcache_block_size
- if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
- raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
  self.update_block(batch_idx, block_idx)
 
  return self.replace_empty_block(self.block_tables[batch_idx])
@@ -169,20 +177,23 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  dec_attn_mask: torch.Tensor,
  page_table_manager: RBLNPageTableManager,
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
- out_buffers: Optional[torch.Tensor] = None,
+ config: Optional["PreTrainedConfig"] = None,
+ logits_last_dim: Optional[int] = None,
  **kwargs: Any,
  ) -> None:
  super().__init__(runtime, **kwargs)
  self.phase = phase
  self.batch_size = batch_size
  self.rbln_config = rbln_config
+ self.config = config
+ self.logits_last_dim = logits_last_dim
 
  # shared resources between prefill and decode phase
  self.dec_attn_mask = dec_attn_mask
  self.page_table_manager = page_table_manager
+ self.out_buffers = None
 
  if self.phase == "prefill":
- self.out_buffers = out_buffers
  self.causal_mask = 1 - torch.triu(
  torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
  )
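The causal mask built in the prefill branch above is a lower-triangular block. A standalone sketch (not part of the diff), using a hypothetical chunk size of 4:

    import torch

    prefill_chunk_size = 4  # hypothetical; the real value comes from rbln_config
    causal_mask = 1 - torch.triu(
        torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
    )
    # causal_mask[0, 0] is lower triangular:
    # tensor([[1., 0., 0., 0.],
    #         [1., 1., 0., 0.],
    #         [1., 1., 1., 0.],
    #         [1., 1., 1., 1.]])
    # During chunked prefill this block is written into the current chunk's
    # columns of the running attention mask, so each token only attends to
    # itself and earlier positions within the chunk.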
@@ -276,28 +287,48 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
  raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
 
- if self.batch_size != cache_position.shape[0]:
+ batch_size = inputs.shape[0]
+ if batch_size != self.batch_size:
  raise RuntimeError(
- f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
+ f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+ )
+
+ if batch_size != cache_position.shape[0]:
+ raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+ if self.rbln_config.use_local_attention:
+ local_block_tables = (
+ local_block_tables
+ if local_block_tables is not None
+ else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
  )
 
  if self.rbln_config.use_attention_mask and attention_mask is None:
- for b_idx in range(self.batch_size):
+ for b_idx in range(batch_size):
  decoding_step = cache_position[b_idx].item()
  if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
  raise ValueError(
  f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
  )
 
- if is_external_block_tables:
- self.dec_attn_mask[b_idx].fill_(0)
- self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+ if self.rbln_config.use_position_ids:
+ self.dec_attn_mask[b_idx, decoding_step] = 1
+
+ if self.batch_size < block_tables.shape[0]:
+ block_tables = block_tables[: self.batch_size]
+
+ if self.dec_attn_mask is not None and self.batch_size < self.dec_attn_mask.shape[0]:
+ self.dec_attn_mask = self.dec_attn_mask[: self.batch_size]
  else:
- self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+ if is_external_block_tables:
+ self.dec_attn_mask[b_idx].fill_(0)
+ self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+ else:
+ self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
 
  attention_mask = self.dec_attn_mask
 
- logits = super().forward(
+ outputs = super().forward(
  inputs,
  cache_position,
  block_tables,
@@ -306,15 +337,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  attention_mask if self.rbln_config.use_attention_mask else None,
  position_ids if self.rbln_config.use_position_ids else None,
  lora_int_ids if self.rbln_config.use_lora else None,
+ out=self.out_buffers,
  )
 
- return RBLNDecoderOnlyOutput(logits=logits)
+ if self.rbln_config.output_hidden_states:
+ return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+ else:
+ return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)
 
  def _prepare_prefill_inputs(
  self,
  inputs: torch.Tensor,
  cache_position: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
  position_embed: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
  ):
@@ -324,9 +360,27 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  # Handle continuous batching in a compiled graph by extracting valid inputs
  # If an attention mask is provided, select only the valid (non-masked) inputs
  if attention_mask is not None:
- inputs = inputs[:, attention_mask.bool()]
- position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
- token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]
+ if attention_mask.dim() != 1:
+ raise ValueError("attention_mask must be a 1D tensor.")
+
+ mask_bool = attention_mask.to(dtype=torch.bool)
+ if (~mask_bool).any():
+ indice_one = torch.nonzero(mask_bool, as_tuple=False)
+ if indice_one.numel() == 0:
+ raise ValueError("attention_mask with padding must include at least one real token.")
+ first_one_idx, last_one_idx = int(indice_one[0].item()), int(indice_one[-1].item())
+ if last_one_idx - first_one_idx + 1 != mask_bool.sum():
+ raise ValueError(
+ "attention_mask must group all 1s together (e.g. 000111 or 1111000). "
+ "Zeros between real tokens like 101010 are not supported."
+ )
+
+ if self.rbln_config.can_generate and not mask_bool[first_one_idx:].all():
+ raise ValueError("attention_mask must be left padded for generation.")
+
+ inputs = inputs[:, mask_bool]
+ position_embed = None if position_embed is None else position_embed[:, :, :, mask_bool, :]
+ token_type_ids = None if token_type_ids is None else token_type_ids[:, mask_bool]
 
  query_length = inputs.shape[1]
  if query_length > self.rbln_config.max_seq_len:
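The new checks only accept 1D masks whose ones form a single contiguous run and, when the model can generate, are left padded. A sketch of which masks pass, mirroring the added logic outside the runtime class (illustrative, not package code):

    import torch

    def check_mask(attention_mask: torch.Tensor, can_generate: bool = True) -> bool:
        # Mirrors the validation added in _prepare_prefill_inputs above.
        if attention_mask.dim() != 1:
            return False
        mask_bool = attention_mask.to(dtype=torch.bool)
        if (~mask_bool).any():
            ones = torch.nonzero(mask_bool, as_tuple=False)
            if ones.numel() == 0:
                return False
            first, last = int(ones[0].item()), int(ones[-1].item())
            if last - first + 1 != mask_bool.sum():
                return False  # holes such as 1,0,1 are rejected
            if can_generate and not mask_bool[first:].all():
                return False  # generation requires left padding
        return True

    print(check_mask(torch.tensor([0, 0, 1, 1, 1])))  # True: left padded
    print(check_mask(torch.tensor([1, 1, 1, 0, 0])))  # False for generation: right padded
    print(check_mask(torch.tensor([1, 0, 1, 0, 1])))  # False: ones are not contiguous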
@@ -335,17 +389,19 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  )
 
  # Initialize attention mask for chunked processing
- chunked_attention_mask = (
- torch.zeros(
- 1,
- 1,
- self.rbln_config.prefill_chunk_size,
- self.rbln_config.max_seq_len,
- dtype=self.rbln_config.torch_dtype,
- )
- if self.rbln_config.use_attention_mask
- else None
- )
+ if self.rbln_config.use_attention_mask:
+ if self.rbln_config.use_position_ids:
+ chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=self.rbln_config.dtype)
+ else:
+ chunked_attention_mask = torch.zeros(
+ 1,
+ 1,
+ self.rbln_config.prefill_chunk_size,
+ self.rbln_config.max_seq_len,
+ dtype=self.rbln_config.dtype,
+ )
+ else:
+ chunked_attention_mask = None
 
  cache_position = (
  torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
@@ -363,7 +419,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  cache_position = F.pad(cache_position, (0, padding_size))
 
  # Overwrite position_ids and padded_cache_lengths
- position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
+ if self.rbln_config.use_position_ids and position_ids is None:
+ position_ids = cache_position.clone()
+ else:
+ position_ids = position_ids
+
  padded_cache_lengths = 0
 
  return (
@@ -377,6 +437,68 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  token_type_ids,
  )
 
+ def _prepare_prefill_outputs(
+ self,
+ query_length: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # Prepare out buffers
+ padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+ padded_input_length = query_length + padding_size
+ padded_mask_length = (
+ attention_mask.shape[-1] + padding_size if attention_mask is not None else padded_input_length
+ )
+ out_buffers = [[] for _ in range(padded_input_length // self.rbln_config.prefill_chunk_size)]
+
+ valid_start_index = (
+ int(torch.nonzero(attention_mask, as_tuple=False)[0][0].item()) if attention_mask is not None else 0
+ )
+
+ if self.logits_last_dim is None:
+ logits_last_dim = self.config.vocab_size if self.rbln_config.can_generate else self.config.hidden_size
+ else:
+ logits_last_dim = self.logits_last_dim
+
+ # Prepare logits buffer
+ logits_size = (
+ 1,
+ 1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
+ logits_last_dim,
+ )
+ output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
+
+ if self.rbln_config.logits_to_keep == 1:
+ for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+ out_buffers[i].append(output_logits)
+ else:
+ for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+ s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+ out_buffers[i].append(output_logits[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size])
+
+ # Prepare output hidden states
+ output_hidden_states = None
+ if self.rbln_config.output_hidden_states:
+ hidden_states_size = (
+ 1,
+ padded_mask_length,
+ self.config.hidden_size,
+ )
+ output_hidden_states = [
+ torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
+ for _ in range(self.config.num_hidden_layers + 1)
+ ]
+
+ for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+ s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+ out_buffers[i].extend(
+ [
+ hidden_states_buffer[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size]
+ for hidden_states_buffer in output_hidden_states
+ ]
+ )
+
+ return out_buffers, output_logits, output_hidden_states
+
  def prefill_forward(
  self,
  inputs: torch.Tensor,
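The buffers prepared by the new _prepare_prefill_outputs are slices (views) of one preallocated logits tensor, so each chunked forward call can write its result in place through the runtime's out= argument. A minimal sketch of the slicing idea (not package code) with hypothetical sizes:

    import torch

    prefill_chunk_size, padded_len, vocab = 4, 8, 11  # hypothetical sizes
    output_logits = torch.full((1, padded_len, vocab), 1e-10)

    # One view per chunk; writing into a view fills the shared parent tensor.
    out_buffers = [
        output_logits[:, i * prefill_chunk_size : (i + 1) * prefill_chunk_size]
        for i in range(padded_len // prefill_chunk_size)
    ]

    # Stand-in for the runtime writing chunk 0's logits into its buffer.
    out_buffers[0].copy_(torch.ones(1, prefill_chunk_size, vocab))
    assert torch.equal(output_logits[:, :prefill_chunk_size], torch.ones(1, prefill_chunk_size, vocab))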
@@ -385,6 +507,7 @@
  batch_idx: Optional[int] = None,
  block_tables: Optional[torch.Tensor] = None,
  is_external_block_tables: Optional[bool] = None,
+ position_ids: Optional[torch.Tensor] = None,
  position_embed: Optional[torch.Tensor] = None,
  token_type_ids: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
@@ -417,9 +540,11 @@
  query_length,
  token_type_ids,
  ) = self._prepare_prefill_inputs(
- inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+ inputs, cache_position, attention_mask, position_ids, position_embed, token_type_ids=token_type_ids
  )
 
+ out_buffers, output_logits, output_hidden_states = self._prepare_prefill_outputs(query_length, attention_mask)
+
  # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
  prefix_cached_len = cache_position[0][0].item()
  if prefix_cached_len > 0:
@@ -428,11 +553,13 @@
  "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
  )
  if self.rbln_config.use_attention_mask:
- chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
+ if self.rbln_config.use_position_ids:
+ chunked_attention_mask[:, :prefix_cached_len] = 1
+ else:
+ chunked_attention_mask[:, :, :, :prefix_cached_len] = 1
 
  # Process input in chunks of size `prefill_chunk_size`
- output_logits = []
- for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+ for i, step in enumerate(range(0, query_length, self.rbln_config.prefill_chunk_size)):
  s, e = step, step + self.rbln_config.prefill_chunk_size
  # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
  input_chunk = inputs[:, s:e]
@@ -441,17 +568,29 @@
  position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None
 
  # Update attention mask to ensure proper causal behavior
- if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
- if step > 0: # update previous chunk
- chunked_attention_mask[
- :,
- :,
- :,
- s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
- - self.rbln_config.prefill_chunk_size
- + prefix_cached_len,
- ] = 1
- chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask
+ if self.rbln_config.use_attention_mask:
+ if self.rbln_config.use_position_ids:
+ if step > 0: # update previous chunk
+ # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+ prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+ prev_chunk_end = s + prefix_cached_len
+ chunked_attention_mask[:, prev_chunk_start:prev_chunk_end] = 1
+
+ current_chunk_start = s + prefix_cached_len
+ current_chunk_end = min(e, query_length) + prefix_cached_len
+ if current_chunk_end > current_chunk_start:
+ chunked_attention_mask[:, current_chunk_start:current_chunk_end] = 1
+
+ else:
+ if step > 0: # update previous chunk
+ # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+ prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+ prev_chunk_end = s + prefix_cached_len
+ chunked_attention_mask[:, :, :, prev_chunk_start:prev_chunk_end] = 1
+
+ current_chunk_start = s + prefix_cached_len
+ current_chunk_end = e + prefix_cached_len
+ chunked_attention_mask[:, :, :, current_chunk_start:current_chunk_end] = self.causal_mask
 
  # Calculate query position if needed
  if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
@@ -464,7 +603,7 @@
  query_position = None
 
  # Forward pass for the current chunk
- output_logit = super().forward(
+ _ = super().forward(
  input_chunk,
  cache_pos_chunk,
  block_tables,
@@ -474,31 +613,33 @@
  chunked_attention_mask if self.rbln_config.use_attention_mask else None,
  position_ids_chunk,
  lora_int_ids if self.rbln_config.use_lora else None,
- out=self.out_buffers,
+ out=out_buffers[i],
  )
- output_logits.append(output_logit)
 
  # Aggregate output_logits
- output_logits = torch.concat(output_logits, dim=-2)
- if self.rbln_config.logits_to_keep > 0:
- output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
+ padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+ if self.rbln_config.logits_to_keep == 1:
+ output_logits = output_logits
+ elif self.rbln_config.logits_to_keep > 1:
+ output_logits = output_logits[:, -padding_size - self.rbln_config.logits_to_keep : -padding_size, :]
  else:
- output_logits = output_logits[:, :query_length, :]
- # index copy for masked output_logits
- if attention_mask is not None:
- new_output_logits = torch.full(
- (1, attention_mask.shape[-1], output_logits.shape[-1]),
- fill_value=1e-10,
- dtype=output_logits.dtype,
- )
- mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
- new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+ output_logits = output_logits[:, :-padding_size, :]
 
- output_logits = new_output_logits
+ all_hidden_states = None
+ if self.rbln_config.output_hidden_states:
+ all_hidden_states = [
+ output_hidden_state[:, :-padding_size, :] for output_hidden_state in output_hidden_states
+ ]
+ all_hidden_states = tuple(all_hidden_states)
 
  # Update decoder attention mask with processed KV-cache length from prefill phase
  if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
- self.dec_attn_mask[batch_idx].fill_(0)
- self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+ if self.rbln_config.use_position_ids:
+ self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+ else:
+ self.dec_attn_mask[batch_idx].fill_(0)
+ self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
- return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
+ return RBLNDecoderOnlyOutput(
+ logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+ )
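The slicing in the hunk above relies on the same padding arithmetic used when the prefill buffers are allocated. With hypothetical numbers (illustrative only, not package code):

    # Hypothetical values, to show the arithmetic only.
    prefill_chunk_size = 128
    query_length = 300

    padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size  # 84
    padded_input_length = query_length + padding_size                        # 384, i.e. 3 chunks
    num_chunks = padded_input_length // prefill_chunk_size                   # 3

    # When logits_to_keep > 1, the slice [-padding_size - logits_to_keep : -padding_size]
    # keeps the last `logits_to_keep` real (non-padding) positions; when logits are kept
    # for the whole prompt, the trailing `padding_size` positions are sliced off.
    assert (padding_size, padded_input_length, num_chunks) == (84, 384, 3)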
@@ -12,10 +12,12 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from typing import TYPE_CHECKING, Any, Dict, Optional
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
  import torch
+ from transformers import GenerationConfig
  from transformers.generation.utils import GenerationMixin
+ from transformers.modeling_outputs import ModelOutput
 
 
  if TYPE_CHECKING:
@@ -91,20 +93,26 @@ class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
  self,
  input_ids: torch.LongTensor,
  attention_mask: Optional[torch.LongTensor] = None,
- max_length: Optional[int] = None,
+ generation_config: Optional[GenerationConfig] = None,
  **kwargs,
- ):
+ ) -> Union[ModelOutput, torch.LongTensor]:
  """
  The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+ Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
 
  Args:
- input_ids: The input ids to the model.
- attention_mask: The attention mask to the model.
- max_length: The maximum length of the generated text.
- kwargs: Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+ input_ids (torch.LongTensor): The input ids to the model.
+ attention_mask (torch.LongTensor, optional): The attention mask to the model.
+ generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+ If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+ Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+ kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+ Returns:
+ A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.
  """
- if max_length is not None:
- kwargs["max_length"] = max_length
+ if generation_config is not None:
+ kwargs["generation_config"] = generation_config
  if attention_mask is not None:
  kwargs["attention_mask"] = attention_mask
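With this change, generation options are passed through a GenerationConfig rather than a bare max_length. A hedged usage sketch; the model class and checkpoint name are illustrative and not taken from this diff:

    from transformers import AutoTokenizer, GenerationConfig

    from optimum.rbln import RBLNLlamaForCausalLM  # illustrative model class

    model_id = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = RBLNLlamaForCausalLM.from_pretrained(model_id, export=True)

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    generation_config = GenerationConfig(max_new_tokens=32, do_sample=False)

    # kwargs matching GenerationConfig attributes would override these values.
    output_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        generation_config=generation_config,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))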
 
@@ -142,7 +142,7 @@ class LoRALinear(nn.Module):
  padded_lora_a = []
  padded_lora_b = []
 
- for i, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
+ for _, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
  current_rank = lora_a.shape[0]
  if current_rank < max_rank:
  # Pad with zeros
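For reference, the padding performed in this loop zero-extends each adapter to the largest rank so that LoRA weights of different ranks can be stacked into a single tensor. A standalone sketch under that assumption (hypothetical shapes, not package code):

    import torch
    import torch.nn.functional as F

    in_features, out_features, max_rank = 16, 16, 8  # hypothetical sizes
    lora_a_weights = [torch.randn(4, in_features), torch.randn(8, in_features)]   # ranks 4 and 8
    lora_b_weights = [torch.randn(out_features, 4), torch.randn(out_features, 8)]

    padded_a, padded_b = [], []
    for lora_a, lora_b in zip(lora_a_weights, lora_b_weights):
        current_rank = lora_a.shape[0]
        if current_rank < max_rank:
            # Pad the rank dimension with zeros; the extra rows of A and columns of B
            # contribute nothing to B @ A, so the adapter's output is unchanged.
            lora_a = F.pad(lora_a, (0, 0, 0, max_rank - current_rank))
            lora_b = F.pad(lora_b, (0, max_rank - current_rank))
        padded_a.append(lora_a)
        padded_b.append(lora_b)

    stacked_a = torch.stack(padded_a)  # (num_adapters, max_rank, in_features)
    stacked_b = torch.stack(padded_b)  # (num_adapters, out_features, max_rank)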