optimum-rbln 0.7.5a0__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. optimum/rbln/__init__.py +30 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +9 -4
  4. optimum/rbln/modeling.py +7 -5
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/sliding_window_attn.py +111 -0
  9. optimum/rbln/transformers/__init__.py +32 -3
  10. optimum/rbln/transformers/models/__init__.py +37 -0
  11. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  12. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  13. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  14. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
  15. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
  16. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +12 -6
  17. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +189 -90
  18. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +186 -95
  19. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  20. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  21. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  22. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
  23. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
  24. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
  25. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  26. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
  27. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  28. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  29. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  30. optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
  31. optimum/rbln/transformers/models/opt/modeling_opt.py +80 -0
  32. optimum/rbln/transformers/models/opt/opt_architecture.py +77 -0
  33. optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
  34. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -11
  35. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
  36. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
  37. optimum/rbln/transformers/models/siglip/__init__.py +20 -0
  38. optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
  39. optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
  40. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
  41. optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
  42. optimum/rbln/utils/submodule.py +13 -1
  43. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
  44. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +46 -31
  45. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
  46. {optimum_rbln-0.7.5a0.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -30,7 +30,7 @@ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import QuantizationManager
+ from ...utils.rbln_quantization import prepare_model_for_quantization
  from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
  from .decoderonly_architecture import (
      DecoderOnlyWrapper,
@@ -59,6 +59,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          kvcache_block_size: int,
          use_attention_mask: bool,
          attn_impl: str,
+         use_position_ids: bool,
          **kwargs: Any,
      ) -> None:
          super().__init__(runtime, **kwargs)
@@ -72,6 +73,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          self.dec_attn_mask = dec_attn_mask
          self.block_tables = block_tables
          self.free_block_pool = free_block_pool
+         self.use_position_ids = use_position_ids

          self.kvcache_block_size = kvcache_block_size
          self.empty_block = -1
@@ -164,6 +166,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          batch_idx: Optional[int] = None,
          block_tables: Optional[torch.Tensor] = None,
          position_embed: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
      ):
          if input_ids is None and inputs_embeds is None:
              raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -189,10 +194,19 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                  is_external_block_tables,
                  attention_mask=attention_mask,
                  position_embed=position_embed,
+                 position_ids=position_ids,
+                 local_block_tables=local_block_tables,
              )
          else:
              return self.prefill_forward(
-                 inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+                 inputs,
+                 cache_position,
+                 attention_mask,
+                 batch_idx,
+                 block_tables,
+                 position_embed=position_embed,
+                 token_type_ids=token_type_ids,
+                 local_block_tables=local_block_tables,
              )

      def decode_forward(
@@ -203,6 +217,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          is_external_block_tables: bool = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_embed: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          batch_size = inputs.shape[0]
          if batch_size != self.batch_size:
@@ -232,35 +248,32 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              if self.batch_size < block_tables.shape[0]:
                  block_tables = block_tables[: self.batch_size]

-             if self.batch_size < attention_mask.shape[0]:
+             if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
                  attention_mask = attention_mask[: self.batch_size]

          logits = super().forward(
              inputs,
              cache_position,
-             attention_mask if self.use_attention_mask else None,
              block_tables,
              position_embed,
+             attention_mask if self.use_attention_mask else None,
+             position_ids if self.use_position_ids else None,
          )

-         return logits
+         return RBLNDecoderOnlyOutput(logits=logits)

-     def prefill_forward(
+     def _prepare_prefill_inputs(
          self,
          inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
+         cache_position: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: int = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
          position_embed: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
+         local_block_tables: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+     ):
          """
-         Performs chunked prefill for efficient KV-cache updates and memory optimization.
-         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+         Prepare inputs for prefill phase.
          """
-
          # Handle continuous batching in a compiled graph by extracting valid inputs
          # If an attention mask is provided, select only the valid (non-masked) inputs
          inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
@@ -276,8 +289,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              )

          # Initialize attention mask for chunked processing
-         if self.use_attention_mask:
-             chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+         chunked_attention_mask = (
+             torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+             if self.use_attention_mask
+             else None
+         )

          # Buffer for storing output logits
          out_buffers = [
@@ -288,40 +304,88 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              )
          ]

-         # Process input in chunks of size `prefill_chunk_size`
-         for step in range(0, query_length, self.prefill_chunk_size):
-             # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-             if (step + self.prefill_chunk_size) > query_length:
-                 padding_size = step + self.prefill_chunk_size - query_length
-                 # inputs_embeds
-                 if inputs.dim() == 3:
-                     inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-                 # inputs_ids
-                 else:
-                     inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+         if query_length % self.prefill_chunk_size != 0:
+             padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+             # inputs_embeds
+             if inputs.dim() == 3:
+                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+             # inputs_ids
+             else:
+                 inputs = torch.nn.functional.pad(inputs, (0, padding_size))

-                 cache_position = torch.cat(
-                     [
-                         cache_position,
-                         torch.arange(
-                             query_length,
-                             step + self.prefill_chunk_size,
-                             dtype=torch.int32,
-                         ).unsqueeze(0),
-                     ],
-                     dim=-1,
-                 )
+             cache_position = torch.cat(
+                 [
+                     cache_position,
+                     torch.arange(
+                         query_length,
+                         query_length + padding_size,
+                         dtype=torch.int32,
+                     ).unsqueeze(0),
+                 ],
+                 dim=-1,
+             )
+
+             if position_embed is not None:
+                 position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
+         # Overwrite position_ids and padded_cache_lengths
+         position_ids = None
+         padded_cache_lengths = 0
+
+         return (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             out_buffers,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+         )

-             if position_embed is not None:
-                 position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+     def prefill_forward(
+         self,
+         inputs: torch.Tensor,
+         cache_position: torch.Tensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         batch_idx: int = None,
+         block_tables: torch.Tensor = None,
+         is_external_block_tables: bool = None,
+         position_embed: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+     ) -> torch.FloatTensor:
+         """
+         Performs chunked prefill for efficient KV-cache updates and memory optimization.
+         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+         """
+         (
+             inputs,
+             cache_position,
+             chunked_attention_mask,
+             out_buffers,
+             position_ids,
+             position_embed,
+             padded_cache_lengths,
+             query_length,
+         ) = self._prepare_prefill_inputs(
+             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+         )

+         # Process input in chunks of size `prefill_chunk_size`
+         for step in range(0, query_length, self.prefill_chunk_size):
              # Extract the current chunk of inputs and cache positions
              input_chunk = inputs[:, step : step + self.prefill_chunk_size]
              cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+             position_ids_chunk = (
+                 position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+             )
              if position_embed is not None:
                  position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

-             if self.use_attention_mask:
+             if self.use_attention_mask and not self.use_position_ids:
                  # Update attention mask to ensure proper causal behavior
                  if step >= self.prefill_chunk_size:
                      chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
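
For reference, a minimal sketch of the arithmetic the new `_prepare_prefill_inputs` helper performs: the whole prompt is now padded once, up front, to the next multiple of `prefill_chunk_size`, instead of padding the final chunk inside the loop as the old code did. The tensor shapes below are illustrative, not taken from the package:

    import torch

    prefill_chunk_size = 128
    input_ids = torch.randint(0, 32000, (1, 300))          # a 300-token prompt
    cache_position = torch.arange(300, dtype=torch.int32).unsqueeze(0)

    query_length = input_ids.shape[-1]
    if query_length % prefill_chunk_size != 0:
        # pad up to the next multiple of the chunk size (300 -> 384, padding_size = 84)
        padding_size = prefill_chunk_size - query_length % prefill_chunk_size
        input_ids = torch.nn.functional.pad(input_ids, (0, padding_size))
        cache_position = torch.cat(
            [cache_position, torch.arange(query_length, query_length + padding_size, dtype=torch.int32).unsqueeze(0)],
            dim=-1,
        )

    assert input_ids.shape[-1] % prefill_chunk_size == 0   # 384
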
@@ -334,10 +398,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              logits = super().forward(
                  input_chunk,
                  cache_pos_chunk,
-                 chunked_attention_mask if self.use_attention_mask else None,
-                 query_position,
                  block_tables,
                  position_embed_chunk if position_embed is not None else None,
+                 query_position,
+                 chunked_attention_mask if self.use_attention_mask else None,
+                 position_ids_chunk if self.use_position_ids else None,
                  out=out_buffers,
              )

@@ -346,13 +411,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              self.dec_attn_mask[batch_idx].fill_(0)
              self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-         return logits
+         return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


  @dataclass
  class RBLNDecoderOnlyOutput(ModelOutput):
      logits: torch.FloatTensor = None
      generate_idx: torch.Tensor = None
+     padded_cache_lengths: int = None


  class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
@@ -386,12 +452,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          if self.rbln_config.use_inputs_embeds:
              main_input_name = "inputs_embeds"
              artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-             with no_init_weights():
-                 self.embed_tokens = torch.nn.Embedding(
-                     self.config.vocab_size,
-                     self.config.hidden_size,
-                     self.config.pad_token_id,
-                 )
+             self.embed_tokens = self._create_embedding_layer()
              self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
          else:
              self.embed_tokens = None
@@ -422,7 +483,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              max_seq_len=self.rbln_config.max_seq_len,
              use_attention_mask=self.rbln_config.use_attention_mask,
              attn_impl=self.rbln_config.attn_impl,
+             use_position_ids=self.rbln_config.use_position_ids,
          )
+
          self.decoders = {}
          for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
              self.decoders[batch_size] = RBLNRuntimeModel(
@@ -437,6 +500,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                  kvcache_block_size=self.rbln_config.kvcache_block_size,
                  use_attention_mask=self.rbln_config.use_attention_mask,
                  attn_impl=self.rbln_config.attn_impl,
+                 use_position_ids=self.rbln_config.use_position_ids,
              )

          # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -459,6 +523,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
          torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.vocab_size,
+                 self.config.hidden_size,
+                 self.config.pad_token_id,
+             )
+         return embed_tokens
+
      def get_input_embeddings(self):
          return self.embed_tokens

@@ -482,8 +555,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          trust_remote_code: bool = False,
          **kwargs,
      ):
-         from ...utils.rbln_quantization import prepare_model_for_quantization
-
          kwargs = cls.update_kwargs(kwargs)

          if config is None:
@@ -500,8 +571,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          with no_init_weights():
              model = AutoModelForCausalLM.from_config(config)

-         prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))
-
+         model = prepare_model_for_quantization(
+             model,
+             model_id,
+             kwargs.get("num_hidden_layers"),
+             use_auth_token=use_auth_token,
+             revision=revision,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             local_files_only=local_files_only,
+         )
          return model

      def __getattr__(self, __name: str) -> Any:
@@ -528,11 +607,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      def get_pytorch_model(
          cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
      ) -> "PreTrainedModel":
-         if (
-             rbln_config is not None
-             and "format" in rbln_config.quantization
-             and rbln_config.quantization["format"] == "rbln"
-         ):
+         if rbln_config and rbln_config.quantization:
              model = cls.get_quantized_model(*args, **kwargs)
          else:
              model = super().get_pytorch_model(*args, **kwargs)
@@ -548,6 +623,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              "kvcache_block_size": rbln_config.kvcache_block_size,
              "use_rotary_emb": cls._use_rotary_emb,
              "use_attention_mask": rbln_config.use_attention_mask,
+             "use_position_ids": rbln_config.use_position_ids,
+             "use_inputs_embeds": rbln_config.use_inputs_embeds,
          }
          return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

@@ -572,9 +649,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              static_tensors[name] = tensor
              context.mark_static_address(tensor)

-         @QuantizationManager.with_quantization_env
-         def compile_model(wrapped_model, compile_config, example_inputs, compile_context, **kwargs):
+         def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
              try:
+                 if quantization:
+                     quantization.maybe_set_quantization_env()
                  original_linear = torch.nn.functional.linear
                  torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
                  compiled_model = RBLNModel.compile(
@@ -586,14 +664,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                  return compiled_model
              finally:
                  torch.nn.functional.linear = original_linear
+                 if quantization:
+                     quantization.maybe_reset_quantization_env()

          wrapped_model.phase = "prefill"
          compiled_prefill = compile_model(
-             wrapped_model,
-             prefill_compile_config,
-             prefill_example_inputs,
-             context,
-             quantize_config=rbln_config.quantization,
+             wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
          )

          wrapped_model.phase = "decode"
@@ -601,11 +677,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
              dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
              compiled_decoder = compile_model(
-                 wrapped_model,
-                 dec_compile_config,
-                 dec_example_inputs,
-                 context,
-                 quantize_config=rbln_config.quantization,
+                 wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
              )
              compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

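The `@QuantizationManager.with_quantization_env` decorator is gone; each `compile_model` call now receives the quantization config directly and brackets compilation with `maybe_set_quantization_env()` / `maybe_reset_quantization_env()`. Purely for illustration, the same set/reset discipline written as a context manager (the package keeps the explicit try/finally shown above):

    from contextlib import contextmanager

    @contextmanager
    def quantization_env(quantization):
        # Mirror the try/finally in the diff: set the quantization environment only when a
        # config object is present, and always restore it afterwards.
        if quantization:
            quantization.maybe_set_quantization_env()
        try:
            yield
        finally:
            if quantization:
                quantization.maybe_reset_quantization_env()
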
@@ -763,6 +835,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          query_length: int,
          use_inputs_embeds: bool,
          use_attention_mask: bool,
+         use_position_ids: bool,
          max_seq_len: int,
          kvcache_block_size: int,
          kvcache_num_blocks: int,
@@ -785,26 +858,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              ),
          ]

-         if use_attention_mask:
+         max_block_cnt = max_seq_len // kvcache_block_size
+
+         if query_length > 1:
+             input_info.extend([("block_tables", [max_block_cnt], "int16")])
+         else:
+             input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+         if query_length > 1:
              input_info.extend(
                  [
-                     ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                     ("query_position", [], "int16"),
                  ]
              )
-
-         if query_length > 1:
+         if use_attention_mask:
              input_info.extend(
                  [
-                     ("query_position", [], "int16"),
+                     ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
                  ]
              )
-
-         max_block_cnt = max_seq_len // kvcache_block_size
-
-         if query_length > 1:
-             input_info.extend([("block_tables", [max_block_cnt], "int16")])
-         else:
-             input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+         if use_position_ids:
+             input_info.append(("position_ids", [batch_size, query_length], "int32"))

          input_info.extend(
              [
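
A hedged sketch of the graph inputs this reordered `get_input_info` appends after the base entries (input ids or embeddings and `cache_position`): `block_tables` and, for prefill graphs, `query_position` now come before the optional `attention_mask`, and `position_ids` is appended last when enabled. The concrete values below are illustrative only:

    # Illustrative values; shapes and dtypes follow the diff above.
    batch_size, query_length, max_seq_len, kvcache_block_size = 1, 128, 4096, 1024
    use_attention_mask, use_position_ids = True, True
    max_block_cnt = max_seq_len // kvcache_block_size  # 4

    input_info = []
    if query_length > 1:   # prefill graph: one shared block table plus a scalar query position
        input_info.append(("block_tables", [max_block_cnt], "int16"))
        input_info.append(("query_position", [], "int16"))
    else:                  # decode graph: one block-table row per batch entry
        input_info.append(("block_tables", [batch_size, max_block_cnt], "int16"))
    if use_attention_mask:
        input_info.append(("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"))
    if use_position_ids:
        input_info.append(("position_ids", [batch_size, query_length], "int32"))
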
@@ -898,6 +972,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              query_length=rbln_config.prefill_chunk_size,
              use_inputs_embeds=rbln_config.use_inputs_embeds,
              use_attention_mask=rbln_config.use_attention_mask,
+             use_position_ids=rbln_config.use_position_ids,
              max_seq_len=rbln_config.max_seq_len,
              kvcache_block_size=rbln_config.kvcache_block_size,
              kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -916,6 +991,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              query_length=1,
              use_inputs_embeds=rbln_config.use_inputs_embeds,
              use_attention_mask=rbln_config.use_attention_mask,
+             use_position_ids=rbln_config.use_position_ids,
              max_seq_len=rbln_config.max_seq_len,
              kvcache_block_size=rbln_config.kvcache_block_size,
              kvcache_num_blocks=rbln_config.kvcache_num_blocks,
@@ -977,6 +1053,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          generate_idx: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
+         padded_cache_lengths: Optional[torch.Tensor] = None,
          **kwargs,
      ):
          model_inputs = {}
@@ -984,13 +1061,17 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

          if is_prefill_phase:
              generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             padded_cache_lengths = torch.zeros_like(generate_idx)
              cache_position = None
+             position_ids = None
          else:
              if inputs_embeds is not None:
-                 raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+                 # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
+                 inputs_embeds = None

              input_ids = input_ids[:, -1:]
-             cache_position = generate_idx
+             position_ids = generate_idx
+             cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
              generate_idx = generate_idx + 1
              model_inputs.update({"input_ids": input_ids})

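A worked example of how `padded_cache_lengths` feeds into the decode-phase inputs prepared above; the numbers are illustrative and follow directly from the padding logic in `_prepare_prefill_inputs`:

    import torch

    # A 300-token prompt with prefill_chunk_size = 128, so chunked prefill wrote 84 padded
    # slots into the KV cache and reported them via padded_cache_lengths.
    generate_idx = torch.tensor([[300]])          # count of real tokens processed so far
    padded_cache_lengths = torch.tensor([[84]])   # padding reported by prefill_forward

    position_ids = generate_idx                           # rotary position of the next token: 300
    cache_position = generate_idx + padded_cache_lengths  # KV-cache slot it is written to: 384
    generate_idx = generate_idx + 1
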
@@ -1009,6 +1090,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                  "attention_mask": attention_mask,
                  "cache_position": cache_position,
                  "generate_idx": generate_idx,
+                 "position_ids": position_ids,
+                 "padded_cache_lengths": padded_cache_lengths,
              }
          )

@@ -1022,6 +1105,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      ) -> Dict[str, Any]:
          # update generate_idx
          model_kwargs["generate_idx"] = outputs.generate_idx
+         model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths

          return model_kwargs

@@ -1032,6 +1116,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          cache_position: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          generate_idx: Optional[torch.Tensor] = None,
+         padded_cache_lengths: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         return_dict: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Tuple[torch.FloatTensor]:
          """
@@ -1045,18 +1133,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              logits = []
              inputs = inputs_embeds if inputs_embeds is not None else input_ids
              batch_size = inputs.shape[0]
-
              for b_idx in range(batch_size):
                  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-                 logit = self.prefill_decoder(
+                 output = self.prefill_decoder(
                      input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                      inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                      attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                      cache_position=cache_position,
                      batch_idx=b_idx,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                  )
-                 logits.append(logit)
-
+                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
+                 logits.append(output.logits)
              logits = torch.cat(logits, dim=0)
          # Decoder
          else:
@@ -1072,9 +1160,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                  input_ids=input_ids,
                  inputs_embeds=inputs_embeds,
                  cache_position=cache_position,
-             )
+                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+             ).logits

-         return RBLNDecoderOnlyOutput(
-             logits=logits,
-             generate_idx=generate_idx,
-         )
+         if not return_dict:
+             return logits, generate_idx, padded_cache_lengths
+         else:
+             return RBLNDecoderOnlyOutput(
+                 logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+             )

optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -41,7 +41,10 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
          for layer in causal_lm.transformer.h:
              if self.attn_impl == "eager":
                  new_self_attn = ExaoneAttention(
-                     layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                     layer.attn.attention,
+                     self.use_attention_mask,
+                     kvcache_block_size=self.kvcache_block_size,
+                     use_position_ids=self.use_position_ids,
                  )
              elif self.attn_impl == "flash_attn":
                  new_self_attn = ExaoneFlashAttention(
@@ -49,6 +52,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
                      kvcache_partition_len=self.kvcache_partition_len,
                      use_attention_mask=self.use_attention_mask,
                      kvcache_block_size=self.kvcache_block_size,
+                     use_position_ids=self.use_position_ids,
                  )
              else:
                  raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")

optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -34,7 +34,10 @@ class GemmaWrapper(DecoderOnlyWrapper):
          for layer in causal_lm.model.layers:
              if self.attn_impl == "eager":
                  new_self_attn = DecoderOnlyAttention(
-                     layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                     layer.self_attn,
+                     self.use_attention_mask,
+                     kvcache_block_size=self.kvcache_block_size,
+                     use_position_ids=self.use_position_ids,
                  )
              elif self.attn_impl == "flash_attn":
                  new_self_attn = DecoderOnlyFlashAttention(
@@ -42,6 +45,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                      kvcache_partition_len=self.kvcache_partition_len,
                      use_attention_mask=self.use_attention_mask,
                      kvcache_block_size=self.kvcache_block_size,
+                     use_position_ids=self.use_position_ids,
                  )
              else:
                  raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")

optimum/rbln/transformers/models/gemma3/__init__.py (new file)

@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig, RBLNGemma3ForConditionalGenerationConfig
+ from .modeling_gemma3 import RBLNGemma3ForCausalLM, RBLNGemma3ForConditionalGeneration
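
Assuming the new Gemma3 classes are also re-exported from the top-level `optimum.rbln` namespace (the `optimum/rbln/__init__.py` and `transformers/__init__.py` changes in this release suggest so), usage would look roughly like the sketch below; the checkpoint id and the `export=True` flag follow the usual optimum-rbln convention and are not taken from this diff:

    # Hedged usage sketch; export path and options are assumptions, not packaged documentation.
    from optimum.rbln import RBLNGemma3ForCausalLM

    model = RBLNGemma3ForCausalLM.from_pretrained(
        "google/gemma-3-4b-it",   # illustrative checkpoint id
        export=True,              # compile for the RBLN NPU on first load
    )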