optimum-rbln 0.8.2a1__py3-none-any.whl → 0.8.2a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (34)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +16 -1
  4. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +3 -0
  5. optimum/rbln/diffusers/modeling_diffusers.py +3 -4
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  7. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  8. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  9. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  10. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +10 -2
  11. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -30
  12. optimum/rbln/modeling.py +2 -3
  13. optimum/rbln/modeling_base.py +17 -13
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/models/__init__.py +2 -0
  16. optimum/rbln/transformers/models/clip/configuration_clip.py +12 -1
  17. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  18. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +13 -1
  19. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -3
  20. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +107 -249
  21. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +18 -1
  22. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  23. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  24. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +377 -0
  25. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +275 -0
  26. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +2 -0
  27. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +2 -0
  28. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -0
  29. optimum/rbln/utils/hub.py +8 -47
  30. optimum/rbln/utils/runtime_utils.py +28 -2
  31. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/METADATA +1 -1
  32. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/RECORD +34 -30
  33. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/WHEEL +0 -0
  34. {optimum_rbln-0.8.2a1.dist-info → optimum_rbln-0.8.2a3.dist-info}/licenses/LICENSE +0 -0
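
The most visible addition in this release is Qwen3 support (items 22-25 above). For orientation, the snippet below mirrors the usage example shipped in the new `configuration_qwen3.py` docstring (see the final hunk of this diff); the model ID and size parameters are the docstring's own illustration, not a recommendation:

```python
from optimum.rbln import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig

# Compile-time configuration for the newly added Qwen3 causal LM support.
config = RBLNQwen3ForCausalLMConfig(
    batch_size=1,
    max_seq_len=40960,
    tensor_parallel_size=4,
    kvcache_partition_len=16384,
)

# export=True compiles the Hugging Face checkpoint for RBLN NPUs.
model = RBLNQwen3ForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B",
    export=True,
    rbln_config=config,
)
```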
@@ -31,15 +31,11 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....utils.logging import get_logger
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
 
 
-logger = get_logger()
-
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
 
@@ -320,194 +316,28 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
         self.decode = self.runtime if self.phase == "decode" else None
 
-    def pad_for_chunked_images(
-        self,
-        inputs: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]:
-        """
-        Pads inputs, attention_mask, and position_ids so image token groups (256 tokens with token_type_ids == 1)
-        start at multiples of prefill_chunk_size (256). Returns padded tensors and total padded length.
-
-        Args:
-            inputs: (1, seq_len, hidden_size) tensor.
-            attention_mask: (1, seq_len) tensor, 1 for valid, 0 for masked.
-            position_ids: (1, seq_len) tensor for RoPE.
-            token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
-
-        Returns:
-            (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
-        """
-
-        if token_type_ids is None:
-            return inputs, attention_mask, position_ids, 0, torch.zeros(inputs.shape[:2], dtype=torch.long)
-
-        seq_len = inputs.shape[1]
-
-        # Find image start positions
-        image_starts = [
-            s
-            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
-        ]
-
-        # Initialize padded tensors
-        padded_input_len = seq_len
-        for image_start in image_starts:
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size
-                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            padded_input_len += pad_needed
-        total_padding = padded_input_len - seq_len
-
-        if inputs.dim() == 3:
-            inputs_padded = torch.zeros(1, padded_input_len, inputs.shape[2], dtype=inputs.dtype)
-        else:
-            inputs_padded = torch.zeros(1, padded_input_len, dtype=inputs.dtype)
-        attention_mask_padded = torch.zeros(1, padded_input_len, dtype=attention_mask.dtype)
-        position_ids_padded = torch.zeros(1, padded_input_len, dtype=position_ids.dtype)
-        token_type_ids_padded = torch.zeros(1, padded_input_len, dtype=token_type_ids.dtype)
-
-        # Fill padded tensors
-        dest_pos = 0
-        src_pos = 0
-        last_pos_id = -1
-        for image_start in image_starts + [seq_len]:
-            # Text segment
-            if src_pos < image_start:
-                length = image_start - src_pos
-                inputs_padded[:, dest_pos : dest_pos + length] = inputs[:, src_pos:image_start]
-                attention_mask_padded[:, dest_pos : dest_pos + length] = attention_mask[:, src_pos:image_start]
-                position_ids_padded[:, dest_pos : dest_pos + length] = position_ids[:, src_pos:image_start]
-                token_type_ids_padded[:, dest_pos : dest_pos + length] = token_type_ids[:, src_pos:image_start]
-                dest_pos += length
-                last_pos_id = position_ids[0, image_start - 1].item()
-                src_pos = image_start
-
-            # Padding
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            if pad_needed and dest_pos < padded_input_len:
-                position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
-                    last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
-                ).unsqueeze(0)
-                dest_pos += pad_needed
-
-            # Image segment
-            if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                dest_pos += self.rbln_config.prefill_chunk_size
-                src_pos += self.rbln_config.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
-
-        return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        token_type_ids = (
-            token_type_ids[:, attention_mask.bool()]
-            if attention_mask is not None and token_type_ids is not None
-            else token_type_ids
-        )
-
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        seq_len = inputs.shape[1]
-        # Initialize attention mask for chunked processing
-        if self.rbln_config.use_attention_mask:
-            chunked_attention_mask = (
-                torch.ones(1, seq_len, dtype=torch.float32)
-                if self.rbln_config.use_position_ids
-                else torch.zeros(
-                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                )
-            )
-        else:
-            chunked_attention_mask = None
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        inputs, chunked_attention_mask, position_ids, padded_cache_lengths, token_type_ids_padded = (
-            self.pad_for_chunked_images(inputs, chunked_attention_mask, cache_position, token_type_ids)
-        )
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Align attention_mask to compiled shape
-        if self.rbln_config.use_position_ids:
-            chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
-            )
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-            token_type_ids_padded = torch.nn.functional.pad(token_type_ids_padded, (0, padding_size))
+    def _prepare_prefill_inputs(self, *args, **kwargs):
+        (
+            inputs,
+            cache_position,
+            chunked_attention_mask,
+            out_buffers,
+            position_ids,
+            position_embed,
+            padded_cache_lengths,
+            query_length,
+            token_type_ids,
+        ) = super()._prepare_prefill_inputs(*args, **kwargs)
 
-        if position_embed is not None:
-            position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+        # chunked_attention_mask shape
+        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
 
-        cache_position = torch.arange(0, query_length + padding_size, dtype=torch.int32).unsqueeze(0)
+        # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
+        padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
+        inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+        cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+        position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+        token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
 
         return (
             inputs,
@@ -518,7 +348,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             position_embed,
             padded_cache_lengths,
             query_length,
-            token_type_ids_padded,
+            token_type_ids,
         )
 
     def prefill_forward(
@@ -541,65 +371,73 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         (
             inputs,
             cache_position,
-            padded_attention_mask,
+            chunked_attention_mask,
             out_buffers,
             position_ids,
             position_embed,
             padded_cache_lengths,
             query_length,
-            token_type_ids_padded,
+            token_type_ids,
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        if not is_external_block_tables:
-            local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = padded_attention_mask[:1]
-
-        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
-
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
+
+        step = 0
+        while step < query_length:
+            # Check if the prefill chunk is an image prefill
+            is_image_prefill = torch.all(
+                token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
+            )
+            prefill_chunk_size = (
+                self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
             )
 
-            if self.rbln_config.use_attention_mask:
-                if self.rbln_config.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = (
-                        padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
-                    )
-
-            # Define query position
-            query_position = (
-                torch.sum(
-                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
-                    dim=-1,
-                    dtype=torch.int16,
-                ).squeeze(0)
-                - 1
+            # Check if the prefill chunk is a text prefill which have image_tokens in it.
+            is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
+                token_type_ids[:, step : step + prefill_chunk_size] == 1
            )
-            if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
-                    raise ValueError("All tokens of image_prefill should be the same image.")
-                else:
-                    logits = self.image_prefill(
-                        input_chunk,
-                        cache_pos_chunk,
-                        block_tables,
-                        local_block_tables,
-                        query_position,
-                        chunked_attention_mask,
-                        position_ids_chunk,
-                        out=out_buffers,
-                    )
+
+            # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
+            is_cross_block_boundary = (
+                step // self.rbln_config.kvcache_block_size
+                != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
+            )
+
+            # Check if the prefill chunk is the last chunk
+            is_last_chunk = step + prefill_chunk_size >= query_length
+
+            if is_cross_block_boundary:
+                padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
+                padded_cache_lengths += padding_size
+
+            # if text_prefill end with image_tokens, we only treat the text part.
+            num_processed_tokens = prefill_chunk_size
+            if is_text_prefill_with_image_tokens:
+                first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
+                num_processed_tokens = first_image_token_idx
+            if is_last_chunk:
+                num_processed_tokens = query_length - step
+
+            input_chunk = inputs[:, step : step + prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
+            position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
+            chunked_attention_mask[
+                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
+            ] = 1
+            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
+
+            if is_image_prefill:
+                logits = self.image_prefill(
+                    input_chunk,
+                    cache_pos_chunk,
+                    block_tables,
+                    local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
+                    out=out_buffers,
+                )
             else:
-                # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
                     cache_pos_chunk,
@@ -611,6 +449,11 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                     out=out_buffers,
                 )
 
+            step += num_processed_tokens
+
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+
         return RBLNGemma3ForCausalLMOutput(
             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
         )
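
To make the reworked loop above easier to follow, here is a small standalone sketch of the chunk-selection logic (plain Python, tensors and runtime calls omitted; `prefill_chunk_size`, `image_prefill_chunk_size`, and `kvcache_block_size` are the config attributes used in the diff, while the helper name and list-based inputs are illustrative only, not the shipped code):

```python
# Illustrative sketch of the new Gemma3 prefill loop: each iteration picks an
# image-sized or text-sized chunk, pads across KV-cache block boundaries, and a
# text chunk that runs into image tokens only consumes its text prefix.

def plan_prefill_chunks(token_types, query_length, text_chunk, image_chunk, block_size):
    """Yield (step, chunk_size, num_processed) tuples analogous to prefill_forward."""
    step = 0
    padded = 0  # mirrors padded_cache_lengths in the diff
    while step < query_length:
        window = token_types[step : step + image_chunk]
        is_image = len(window) == image_chunk and all(t == 1 for t in window)
        chunk = image_chunk if is_image else text_chunk

        # Pad when a chunk would straddle a KV-cache block boundary.
        if step // block_size != (step + chunk) // block_size:
            padded += chunk - (step + chunk) % block_size

        # A text chunk containing image tokens only processes the leading text part.
        processed = chunk
        if not is_image and any(t == 1 for t in token_types[step : step + chunk]):
            processed = next(i for i, t in enumerate(token_types[step : step + chunk]) if t == 1)
        if step + chunk >= query_length:
            processed = query_length - step

        yield step, chunk, processed
        step += processed


# Example: 4 text tokens followed by a 4-token image group.
print(list(plan_prefill_chunks([0, 0, 0, 0, 1, 1, 1, 1], 8, 4, 4, 16)))
# -> [(0, 4, 4), (4, 4, 4)]
```

The key behavioural change is that the step no longer advances by a fixed `prefill_chunk_size`: an image group consumes an image-sized chunk, a text chunk that reaches image tokens stops at the first image token, and block-boundary padding is accumulated in `padded_cache_lengths`.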
@@ -757,13 +600,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        if rbln_config.prefill_chunk_size is None:
-            rbln_config.prefill_chunk_size = model.config.mm_tokens_per_image
+        if rbln_config.image_prefill_chunk_size is None:
+            rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
-        if rbln_config.prefill_chunk_size != model.config.mm_tokens_per_image:
-            logger.warning(
-                f"Prefill chunk size is different from mm_tokens_per_image: {rbln_config.prefill_chunk_size} != {model.config.mm_tokens_per_image}"
+        if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+            raise ValueError(
+                f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
             )
+
         return rbln_config
 
     @classmethod
@@ -777,14 +621,22 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         # Update rbln_config with super class
         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-        # Assume that prefill compile config is at index 0
-        compile_cfgs = rbln_config.compile_cfgs
+        if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+        # Update image prefill compile config
+        img_prefill_input_info = cls.get_input_info(
+            batch_size=1,
+            query_length=rbln_config.image_prefill_chunk_size,
+            rbln_config=rbln_config,
+            model_config=model_config,
+        )
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
+            compiled_model_name="image_prefill", input_info=img_prefill_input_info
         )
         # Insert image_prefill compile config at index 1
-        image_idx = 1
-        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        compile_cfgs = rbln_config.compile_cfgs
+        compile_cfgs.insert(1, image_prefill_compile_config)
         rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
@@ -840,11 +692,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         )
 
         image_prefill_compile_config = rbln_compile_configs[1]
+        image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+            fill=0, static_tensors=static_tensors
+        )
         wrapped_model.phase = "image_prefill"
         compiled_image_prefill = compile_model(
             wrapped_model,
             image_prefill_compile_config,
-            prefill_example_inputs,
+            image_prefill_example_inputs,
             context,
             rbln_config.quantization,
         )
@@ -884,12 +739,14 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
             rebel.Runtime(
                 compiled_models[1],
                 tensor_type="pt",
                 device=rbln_config.device_map["image_prefill"],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             ),
             *[
                 rebel.Runtime(
@@ -897,6 +754,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
             ],
@@ -15,6 +15,11 @@
 from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+from ...models.clip import RBLNCLIPVisionModelConfig
+
+
+logger = get_logger(__name__)
 
 
 class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
@@ -50,5 +55,17 @@ class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_tower = vision_tower
+        self.vision_tower = self.init_submodule_config(
+            RBLNCLIPVisionModelConfig,
+            vision_tower,
+        )
+
+        if self.vision_tower.output_hidden_states is False:
+            raise ValueError(
+                f"LlavaNext requires output_hidden_states to be True, but found output_hidden_states={self.vision_tower.output_hidden_states}. "
+                f"Please compile again with the correct argument."
+            )
+        else:
+            self.vision_tower.output_hidden_states = True
+
         self.language_model = language_model
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_qwen3 import RBLNQwen3ForCausalLMConfig, RBLNQwen3ModelConfig
+from .modeling_qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3Model
@@ -0,0 +1,71 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen3 models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNQwen3ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=40960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=16384
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+
+class RBLNQwen3ModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen3 models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen3Model, RBLNQwen3ModelConfig
+
+    # Create a configuration object
+    config = RBLNQwen3ModelConfig(
+        batch_size=1,
+        max_seq_len=40960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=16384
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNQwen3Model.from_pretrained(
+        "Qwen/Qwen3-Embedding-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """