optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (120)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  23. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  24. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  25. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  31. optimum/rbln/modeling.py +4 -5
  32. optimum/rbln/modeling_base.py +18 -14
  33. optimum/rbln/ops/kv_cache_update.py +5 -0
  34. optimum/rbln/ops/linear.py +7 -0
  35. optimum/rbln/transformers/__init__.py +60 -0
  36. optimum/rbln/transformers/configuration_generic.py +4 -4
  37. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  38. optimum/rbln/transformers/modeling_generic.py +1 -4
  39. optimum/rbln/transformers/models/__init__.py +45 -30
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  41. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  42. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  43. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  44. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  45. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  46. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  47. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  48. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  51. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  52. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  53. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  54. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  55. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  56. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  57. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  58. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  59. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  60. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  61. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  62. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  63. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  64. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  65. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  66. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  67. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  68. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  69. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  75. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  76. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  77. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  78. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  79. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  80. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  81. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  82. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  83. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  84. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  85. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  86. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  87. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  91. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  92. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  93. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  94. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  97. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  101. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  102. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  103. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  104. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  105. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  106. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  107. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  108. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  110. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  111. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  112. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  113. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  114. optimum/rbln/utils/depreacate_utils.py +16 -0
  115. optimum/rbln/utils/hub.py +8 -47
  116. optimum/rbln/utils/runtime_utils.py +31 -5
  117. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  118. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
  119. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  120. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -19,33 +19,28 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     AutoModelForImageTextToText,
-     Gemma3ForConditionalGeneration,
-     PretrainedConfig,
-     PreTrainedModel,
- )
+ from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.modeling_utils import no_init_weights
  from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
- from ....utils.logging import get_logger
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
+ from ..decoderonly.modeling_decoderonly import (
+     RBLNDecoderOnlyForCausalLMOutput,
+     RBLNDecoderOnlyModelForCausalLM,
+     RBLNRuntimeModel,
+ )
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper


- logger = get_logger()
-
-
  if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


  @dataclass
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
+ class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
      attention_mask: Optional[torch.Tensor] = None

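Throughout this file, RBLNDecoderOnlyOutput is renamed to RBLNDecoderOnlyForCausalLMOutput, and the Gemma3 output dataclass keeps extending it with an attention_mask field. Schematically (a hypothetical minimal shape of the base class, whose real definition lives in optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py; the field list is inferred from how the class is constructed later in this diff):

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class RBLNDecoderOnlyForCausalLMOutput:
        logits: torch.Tensor = None
        generate_idx: Optional[torch.Tensor] = None
        padded_cache_lengths: Optional[torch.Tensor] = None

    @dataclass
    class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
        attention_mask: Optional[torch.Tensor] = None
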
@@ -201,7 +196,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

      def _update_model_kwargs_for_generation(
          self,
-         outputs: RBLNDecoderOnlyOutput,
+         outputs: RBLNDecoderOnlyForCausalLMOutput,
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
@@ -258,19 +253,47 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

          return inputs_embeds

+     def get_padded_cache_position(
+         self,
+         cache_position: torch.Tensor,  # shape: [1, seq_len]
+         token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+     ) -> torch.Tensor:
+         seq_len = cache_position[0][-1].item() + 1
+
+         # Find image start positions
+         image_starts = [
+             s
+             for s in torch.where(token_type_ids == 1)[1]
+             if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+         ]
+
+         # Accumulate the padding needed so each image block starts at a chunk boundary
+         padded_input_len = seq_len
+         for image_start in image_starts:
+             pad_needed = (
+                 self.rbln_config.image_prefill_chunk_size
+                 - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+             ) % self.rbln_config.image_prefill_chunk_size
+             padded_input_len += pad_needed
+
+         return torch.cat(
+             [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+             dim=1,
+         )
+
      def forward(
          self,
          input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
          pixel_values: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
          cache_position: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          generate_idx: Optional[torch.Tensor] = None,
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
          **lm_kwargs: Dict[str, Any],
-     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+     ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
          # prefill
          if cache_position is None:
              logits = []
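
get_padded_cache_position extends cache_position so that every full image block (a run of image_prefill_chunk_size tokens whose token_type_ids are all 1) lands on a chunk boundary. A standalone sketch of the same arithmetic, with illustrative names and the chunk size fixed at Gemma3's 256 mm_tokens_per_image:

    import torch

    def padded_length(token_type_ids: torch.Tensor, chunk: int = 256) -> int:
        # token_type_ids: [1, seq_len]; 0 marks text tokens, 1 marks image tokens
        seq_len = token_type_ids.shape[1]
        # Starts of full image blocks: the next `chunk` positions are all image tokens
        image_starts = [
            s.item()
            for s in torch.where(token_type_ids == 1)[1]
            if torch.all(token_type_ids[:, s : s + chunk] == 1)
        ]
        padded = seq_len
        for start in image_starts:
            shifted = start + padded - seq_len  # block start after padding added so far
            padded += (chunk - shifted % chunk) % chunk  # pad up to the next boundary
        return padded

    # 100 text tokens, one 256-token image block, one trailing text token:
    ttids = torch.cat([torch.zeros(1, 100), torch.ones(1, 256), torch.zeros(1, 1)], dim=1).long()
    assert padded_length(ttids) == 513  # 156 pad positions move the image block to offset 256
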
@@ -279,12 +302,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

              for b_idx in range(batch_size):
                  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                  output = self.language_model.prefill_decoder(
                      inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                      attention_mask=attention_mask[b_idx],
                      cache_position=cache_position,
                      batch_idx=b_idx,
-                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # intentionally the full row, not the masked token_type_id
                  )
                  padded_cache_lengths[b_idx] += output.padded_cache_lengths
                  logits.append(output.logits)
@@ -308,7 +334,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
                  position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
              ).logits

-         return RBLNDecoderOnlyOutput(
+         return RBLNDecoderOnlyForCausalLMOutput(
              logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
          )

@@ -320,194 +346,30 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
          self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
          self.decode = self.runtime if self.phase == "decode" else None

-     def pad_for_chunked_images(
-         self,
-         inputs: torch.Tensor,
-         attention_mask: torch.Tensor,
-         position_ids: torch.Tensor,
-         token_type_ids: Optional[torch.Tensor] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]:
-         """
-         Pads inputs, attention_mask, and position_ids so image token groups (256 tokens with token_type_ids == 1)
-         start at multiples of prefill_chunk_size (256). Returns padded tensors and total padded length.
-
-         Args:
-             inputs: (1, seq_len, hidden_size) tensor.
-             attention_mask: (1, seq_len) tensor, 1 for valid, 0 for masked.
-             position_ids: (1, seq_len) tensor for RoPE.
-             token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
-
-         Returns:
-             (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
-         """
-
-         if token_type_ids is None:
-             return inputs, attention_mask, position_ids, 0, torch.zeros(inputs.shape[:2], dtype=torch.long)
-
-         seq_len = inputs.shape[1]
-
-         # Find image start positions
-         image_starts = [
-             s
-             for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
-             if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
-         ]
-
-         # Initialize padded tensors
-         padded_input_len = seq_len
-         for image_start in image_starts:
-             pad_needed = (
-                 self.rbln_config.prefill_chunk_size
-                 - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
-             ) % self.rbln_config.prefill_chunk_size
-             padded_input_len += pad_needed
-         total_padding = padded_input_len - seq_len
-
-         if inputs.dim() == 3:
-             inputs_padded = torch.zeros(1, padded_input_len, inputs.shape[2], dtype=inputs.dtype)
-         else:
-             inputs_padded = torch.zeros(1, padded_input_len, dtype=inputs.dtype)
-         attention_mask_padded = torch.zeros(1, padded_input_len, dtype=attention_mask.dtype)
-         position_ids_padded = torch.zeros(1, padded_input_len, dtype=position_ids.dtype)
-         token_type_ids_padded = torch.zeros(1, padded_input_len, dtype=token_type_ids.dtype)
-
-         # Fill padded tensors
-         dest_pos = 0
-         src_pos = 0
-         last_pos_id = -1
-         for image_start in image_starts + [seq_len]:
-             # Text segment
-             if src_pos < image_start:
-                 length = image_start - src_pos
-                 inputs_padded[:, dest_pos : dest_pos + length] = inputs[:, src_pos:image_start]
-                 attention_mask_padded[:, dest_pos : dest_pos + length] = attention_mask[:, src_pos:image_start]
-                 position_ids_padded[:, dest_pos : dest_pos + length] = position_ids[:, src_pos:image_start]
-                 token_type_ids_padded[:, dest_pos : dest_pos + length] = token_type_ids[:, src_pos:image_start]
-                 dest_pos += length
-                 last_pos_id = position_ids[0, image_start - 1].item()
-                 src_pos = image_start
-
-             # Padding
-             pad_needed = (
-                 self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
-             ) % self.rbln_config.prefill_chunk_size
-             if pad_needed and dest_pos < padded_input_len:
-                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
-                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
-                 ).unsqueeze(0)
-                 dest_pos += pad_needed
-
-             # Image segment
-             if src_pos < seq_len and src_pos == image_start:
-                 inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 dest_pos += self.rbln_config.prefill_chunk_size
-                 src_pos += self.rbln_config.prefill_chunk_size
-                 last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
-
-         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
-
-     def _prepare_prefill_inputs(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-     ):
-         """
-         Prepare inputs for prefill phase.
-         """
-         # Handle continuous batching in a compiled graph by extracting valid inputs
-         # If an attention mask is provided, select only the valid (non-masked) inputs
-         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-         token_type_ids = (
-             token_type_ids[:, attention_mask.bool()]
-             if attention_mask is not None and token_type_ids is not None
-             else token_type_ids
-         )
-
-         if position_embed is not None:
-             position_embed = (
-                 position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-             )
-
-         seq_len = inputs.shape[1]
-         # Initialize attention mask for chunked processing
-         if self.rbln_config.use_attention_mask:
-             chunked_attention_mask = (
-                 torch.ones(1, seq_len, dtype=torch.float32)
-                 if self.rbln_config.use_position_ids
-                 else torch.zeros(
-                     1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                 )
-             )
-         else:
-             chunked_attention_mask = None
-
-         # Buffer for storing output logits
-         out_buffers = [
-             torch.empty(
-                 size=self.output_size,
-                 dtype=torch.float32,
-                 device="cpu",
-             )
-         ]
-
-         inputs, chunked_attention_mask, position_ids, padded_cache_lengths, token_type_ids_padded = (
-             self.pad_for_chunked_images(inputs, chunked_attention_mask, cache_position, token_type_ids)
-         )
-
-         query_length = inputs.shape[1]
-         if query_length > self.rbln_config.max_seq_len:
-             raise ValueError(
-                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-             )
-
-         # Align attention_mask to compiled shape
-         if self.rbln_config.use_position_ids:
-             chunked_attention_mask = torch.nn.functional.pad(
-                 chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
-             )
-
-         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-         padding_size = 0
-         if query_length % self.rbln_config.prefill_chunk_size != 0:
-             padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-             # inputs_embeds
-             if inputs.dim() == 3:
-                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-             # inputs_ids
-             else:
-                 inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-             position_ids = torch.cat(
-                 [
-                     position_ids,
-                     torch.arange(
-                         query_length,
-                         query_length + padding_size,
-                         dtype=torch.int32,
-                     ).unsqueeze(0),
-                 ],
-                 dim=-1,
-             )
-             token_type_ids_padded = torch.nn.functional.pad(token_type_ids_padded, (0, padding_size))
+     def _prepare_prefill_inputs(self, *args, **kwargs):
+         (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             out_buffers,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+             token_type_ids,
+         ) = super()._prepare_prefill_inputs(*args, **kwargs)

-         if position_embed is not None:
-             position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+         # Reset chunked_attention_mask to zeros over the full padded length
+         chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)

-         cache_position = torch.arange(0, query_length + padding_size, dtype=torch.int32).unsqueeze(0)
+         # For Gemma3ForConditionalGeneration, the loop counter may not advance by prefill_chunk_size,
+         # so the last chunk is not guaranteed to start at a position that is a multiple of prefill_chunk_size.
+         if self.rbln_config.use_image_prefill:
+             padding_size = self.rbln_config.image_prefill_chunk_size
+             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+             cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+             position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+             token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)

          return (
              inputs,
@@ -518,7 +380,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
              position_embed,
              padded_cache_lengths,
              query_length,
-             token_type_ids_padded,
+             token_type_ids,
          )

      def prefill_forward(
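
The override above delegates to the base class implementation and then pads every prefill input by one extra image_prefill_chunk_size, tagging the padded positions with token type -1 so they match neither text (0) nor image (1) when the chunk loop below classifies windows. The token_type_ids padding call behaves like this minimal illustration:

    import torch

    token_type_ids = torch.tensor([[0, 0, 1, 1]])
    padded = torch.nn.functional.pad(token_type_ids, (0, 4), value=-1)
    print(padded)  # tensor([[ 0,  0,  1,  1, -1, -1, -1, -1]])
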
@@ -541,65 +403,69 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
          (
              inputs,
              cache_position,
-             padded_attention_mask,
+             chunked_attention_mask,
              out_buffers,
              position_ids,
              position_embed,
              padded_cache_lengths,
              query_length,
-             token_type_ids_padded,
+             token_type_ids,
          ) = self._prepare_prefill_inputs(
              inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
          )
-         if not is_external_block_tables:
-             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
-             self.dec_attn_mask[batch_idx : batch_idx + 1] = padded_attention_mask[:1]

-         if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
-             chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
+         step = 0
+         while step < query_length:
+             if self.rbln_config.use_image_prefill:
+                 # Check if the prefill chunk is an image prefill
+                 is_image_prefill = torch.all(
+                     token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
+                 )
+                 # Check if the prefill chunk is a text prefill that has image tokens in it.
+                 is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
+                     token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
+                 )
+             else:
+                 is_image_prefill, is_text_prefill_with_image_tokens = False, False
+
+             # Check if the prefill chunk is the last chunk
+             is_last_chunk = step + self.rbln_config.prefill_chunk_size >= query_length

-         # Process input in chunks of size `prefill_chunk_size`
-         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-             # Extract the current chunk of inputs and cache positions
              input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-             position_ids_chunk = (
-                 position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                 if position_ids is not None
-                 else None
-             )
-
-             if self.rbln_config.use_attention_mask:
-                 if self.rbln_config.use_position_ids:
-                     chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = (
-                         padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
-                     )
-
-             # Define query position
-             query_position = (
-                 torch.sum(
-                     chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
-                     dim=-1,
-                     dtype=torch.int16,
-                 ).squeeze(0)
-                 - 1
+             cache_pos_chunk = (
+                 cache_position[:, step : step + self.rbln_config.prefill_chunk_size] + padded_cache_lengths
              )
-             if token_type_ids_padded[:, step] == 1:
-                 if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
-                     raise ValueError("All tokens of image_prefill should be the same image.")
-                 else:
-                     logits = self.image_prefill(
-                         input_chunk,
-                         cache_pos_chunk,
-                         block_tables,
-                         local_block_tables,
-                         query_position,
-                         chunked_attention_mask,
-                         position_ids_chunk,
-                         out=out_buffers,
-                     )
+             position_ids_chunk = position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+
+             # If a text prefill chunk runs into image tokens, process only the leading text part.
+             num_processed_tokens = self.rbln_config.prefill_chunk_size
+             current_padded_cache_lengths = 0
+             if is_text_prefill_with_image_tokens:
+                 first_image_token_idx = torch.where(
+                     token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
+                 )[1][0]
+                 num_processed_tokens = first_image_token_idx.item()
+                 current_padded_cache_lengths = self.rbln_config.prefill_chunk_size - num_processed_tokens
+             if is_last_chunk:
+                 num_processed_tokens = query_length - step
+
+             chunked_attention_mask[
+                 :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
+             ] = 1
+             query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
+
+             if is_image_prefill:
+                 logits = self.image_prefill(
+                     input_chunk,
+                     cache_pos_chunk,
+                     block_tables,
+                     local_block_tables,
+                     query_position,
+                     chunked_attention_mask,
+                     position_ids_chunk,
+                     out=out_buffers,
+                 )
              else:
-                 # Forward pass for the current chunk
                  logits = self.prefill(
                      input_chunk,
                      cache_pos_chunk,
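
The rewritten loop walks the padded sequence with a variable step instead of a fixed prefill_chunk_size stride: an all-image window is dispatched to the image_prefill runtime, and a text window that runs into image tokens is cut short at the first image token, with the remainder becoming cache padding so the image block starts its own chunk. A simplified sketch of the window classification, with illustrative names and without the last-chunk handling:

    import torch

    def classify_window(token_type_ids: torch.Tensor, step: int, chunk: int = 256):
        # Mirrors the is_image_prefill / is_text_prefill_with_image_tokens checks
        window = token_type_ids[:, step : step + chunk]
        if torch.all(window == 1):
            return "image", chunk  # route to the image_prefill runtime
        if torch.any(window == 1):
            first_image = torch.where(window == 1)[1][0].item()
            return "text", first_image  # process only the leading text part
        return "text", min(chunk, window.shape[1])  # plain text window

    text_into_image = torch.cat([torch.zeros(1, 100), torch.ones(1, 156)], dim=1).long()
    print(classify_window(text_into_image, 0))  # ('text', 100)
    print(classify_window(torch.ones(1, 256).long(), 0))  # ('image', 256)
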
@@ -611,6 +477,12 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                      out=out_buffers,
                  )

+             padded_cache_lengths += current_padded_cache_lengths
+             step += num_processed_tokens
+
+         if not is_external_block_tables:
+             self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+
          return RBLNGemma3ForCausalLMOutput(
              logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
          )
@@ -666,7 +538,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

          logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)

-         return RBLNDecoderOnlyOutput(logits=logits)
+         return RBLNDecoderOnlyForCausalLMOutput(logits=logits)


  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
@@ -701,9 +573,10 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              dtype=torch.int16,
          ).fill_(-1)
          free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
          self.prefill_decoder = RBLNGemma3RuntimeModel(
              runtime=self.model[0],
-             image_prefill=self.model[1],
+             image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
              main_input_name=main_input_name,
              embed_tokens=self.embed_tokens,
              phase="prefill",
@@ -718,7 +591,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          self.decoders = {}
          for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
              self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                 runtime=self.model[i + 2],
+                 runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                  main_input_name=main_input_name,
                  embed_tokens=self.embed_tokens,
                  phase="decode",
@@ -757,13 +630,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):

      @classmethod
      def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-         if rbln_config.prefill_chunk_size is None:
-             rbln_config.prefill_chunk_size = model.config.mm_tokens_per_image
+         if rbln_config.image_prefill_chunk_size is None:
+             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image

-         if rbln_config.prefill_chunk_size != model.config.mm_tokens_per_image:
-             logger.warning(
-                 f"Prefill chunk size is different from mm_tokens_per_image: {rbln_config.prefill_chunk_size} != {model.config.mm_tokens_per_image}"
+         if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+             raise ValueError(
+                 f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
              )
+
          return rbln_config

      @classmethod
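
_update_submodule_config now defaults image_prefill_chunk_size from the model config and hard-fails on a mismatch where 0.8.1rc0 only logged a warning. A standalone equivalent of the check (a hypothetical helper, for illustration only):

    def resolve_image_prefill_chunk_size(configured, mm_tokens_per_image):
        # Default to the model value, raise on mismatch (previously a warning)
        if configured is None:
            return mm_tokens_per_image
        if configured != mm_tokens_per_image:
            raise ValueError(
                f"Image prefill chunk size is different from mm_tokens_per_image: "
                f"{configured} != {mm_tokens_per_image}"
            )
        return configured

    assert resolve_image_prefill_chunk_size(None, 256) == 256
    assert resolve_image_prefill_chunk_size(256, 256) == 256
    # resolve_image_prefill_chunk_size(128, 256) raises ValueError
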
@@ -777,15 +651,29 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          # Update rbln_config with super class
          rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

-         # Assume that prefill compile config is at index 0
-         compile_cfgs = rbln_config.compile_cfgs
-         image_prefill_compile_config = RBLNCompileConfig(
-             compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
-         )
-         # Insert image_prefill compile config at index 1
-         image_idx = 1
-         compile_cfgs.insert(image_idx, image_prefill_compile_config)
-         rbln_config.set_compile_cfgs(compile_cfgs)
+         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+         if rbln_config.use_image_prefill:
+             if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                 raise NotImplementedError(
+                     "Not implemented for different prefill chunk sizes between text and image prefill."
+                 )
+
+             # Update image prefill compile config
+             img_prefill_input_info = cls.get_input_info(
+                 batch_size=1,
+                 query_length=rbln_config.image_prefill_chunk_size,
+                 rbln_config=rbln_config,
+                 model_config=model_config,
+             )
+             image_prefill_compile_config = RBLNCompileConfig(
+                 compiled_model_name="image_prefill", input_info=img_prefill_input_info
+             )
+             # Insert image_prefill compile config at index 1
+             compile_cfgs = rbln_config.compile_cfgs
+             compile_cfgs.insert(1, image_prefill_compile_config)
+             rbln_config.set_compile_cfgs(compile_cfgs)

          return rbln_config

@@ -838,20 +726,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              context,
              rbln_config.quantization,
          )
+         compiled_models = {"prefill": compiled_prefill}

-         image_prefill_compile_config = rbln_compile_configs[1]
-         wrapped_model.phase = "image_prefill"
-         compiled_image_prefill = compile_model(
-             wrapped_model,
-             image_prefill_compile_config,
-             prefill_example_inputs,
-             context,
-             rbln_config.quantization,
-         )
+         if rbln_config.use_image_prefill:
+             image_prefill_compile_config = rbln_compile_configs[1]
+             image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                 fill=0, static_tensors=static_tensors
+             )
+             wrapped_model.phase = "image_prefill"
+             compiled_image_prefill = compile_model(
+                 wrapped_model,
+                 image_prefill_compile_config,
+                 image_prefill_example_inputs,
+                 context,
+                 rbln_config.quantization,
+             )
+             compiled_models["image_prefill"] = compiled_image_prefill

-         compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
          wrapped_model.phase = "decode"
-         for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[2:]):
+         for batch_size, dec_compile_config in zip(
+             rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+         ):
              dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
              compiled_decoder = compile_model(
                  wrapped_model,
@@ -872,32 +767,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      ) -> List[rebel.Runtime]:
          expected_model_names = [
              "prefill",
-             "image_prefill",
              *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
          ]
+         if rbln_config.use_image_prefill:
+             expected_model_names.insert(1, "image_prefill")
+
          if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
              cls._raise_missing_compiled_file_error(expected_model_names)

-         return [
+         ret_val = [
              rebel.Runtime(
                  compiled_models[0],
                  tensor_type="pt",
                  device=rbln_config.device_map["prefill"],
                  activate_profiler=rbln_config.activate_profiler,
-             ),
-             rebel.Runtime(
-                 compiled_models[1],
-                 tensor_type="pt",
-                 device=rbln_config.device_map["image_prefill"],
-                 activate_profiler=rbln_config.activate_profiler,
-             ),
-             *[
+                 timeout=rbln_config.timeout,
+             )
+         ]
+         if rbln_config.use_image_prefill:
+             ret_val.append(
                  rebel.Runtime(
-                     compiled_models[i + 2],
+                     compiled_models[1],
+                     tensor_type="pt",
+                     device=rbln_config.device_map["image_prefill"],
+                     activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
+                 ),
+             )
+
+         ret_val.extend(
+             [
+                 rebel.Runtime(
+                     compiled_models[i + rbln_config.decoder_runtime_idx],
                      tensor_type="pt",
                      device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                      activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
                  )
                  for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-             ],
-         ]
+             ]
+         )
+
+         return ret_val
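
With image prefill now optional, the runtime layout depends on the config: index 0 is always the text prefill runtime, index 1 is the image prefill runtime only when use_image_prefill is set, and decoders start at decoder_runtime_idx (a config property not shown in this diff; presumably 2 with image prefill and 1 without). A sketch of the expected name ordering, mirroring expected_model_names above:

    def runtime_layout(use_image_prefill: bool, decoder_batch_sizes):
        names = ["prefill"]
        if use_image_prefill:
            names.insert(1, "image_prefill")
        names += [f"decoder_batch_{bs}" for bs in decoder_batch_sizes]
        return names

    print(runtime_layout(True, [1, 4]))
    # ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']
    print(runtime_layout(False, [1, 4]))
    # ['prefill', 'decoder_batch_1', 'decoder_batch_4']
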
optimum/rbln/transformers/models/gpt2/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig
- from .modeling_gpt2 import RBLNGPT2LMHeadModel
+ from .configuration_gpt2 import RBLNGPT2LMHeadModelConfig, RBLNGPT2ModelConfig
+ from .modeling_gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2Model
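
The gpt2 package now also exports RBLNGPT2Model (the bare backbone without the LM head) alongside RBLNGPT2LMHeadModel. A hedged usage sketch, assuming the new class follows the same from_pretrained/export convention as the other RBLN model classes:

    from optimum.rbln import RBLNGPT2Model

    # export=True compiles the Hugging Face checkpoint for RBLN NPUs;
    # the compiled artifacts can then be reloaded without re-exporting.
    model = RBLNGPT2Model.from_pretrained("gpt2", export=True)
    model.save_pretrained("gpt2-rbln")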