optimum-rbln 0.8.2rc0__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (105)
  1. optimum/rbln/__init__.py +32 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +24 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +31 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +25 -251
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -12,43 +12,32 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import inspect
- from collections import deque
- from dataclasses import dataclass
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     AutoModelForImageTextToText,
-     Gemma3ForConditionalGeneration,
-     PretrainedConfig,
-     PreTrainedModel,
- )
+ from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.modeling_utils import no_init_weights
  from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
  from ..decoderonly.modeling_decoderonly import (
-     RBLNDecoderOnlyForCausalLMOutput,
      RBLNDecoderOnlyModelForCausalLM,
-     RBLNRuntimeModel,
  )
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper
+ from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


  if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


- @dataclass
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
-     attention_mask: Optional[torch.Tensor] = None
-
-
  class LoopVisionTower:
      def __init__(self, vision_tower: RBLNModel) -> None:
          self.vision_tower = vision_tower
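The import changes above capture two refactors in 0.8.3: the local `RBLNGemma3ForCausalLMOutput` / `RBLNDecoderOnlyForCausalLMOutput` dataclasses give way to a shared `RBLNDecoderOnlyOutput` in the new `optimum/rbln/transformers/modeling_outputs.py` (+37 lines), and the Gemma3 runtime wrapper moves out to `gemma3_runtime_utils.py`. Inferring only from the fields used at the call sites in this diff (`logits`, `generate_idx`, `padded_cache_lengths`), the shared output class plausibly looks like the sketch below; the actual definition may differ.

```python
# Plausible sketch only; the real class lives in
# optimum/rbln/transformers/modeling_outputs.py and may carry more fields.
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class RBLNDecoderOnlyOutput(ModelOutput):
    logits: torch.FloatTensor = None
    generate_idx: Optional[torch.Tensor] = None      # per-sequence generation cursor
    padded_cache_lengths: Optional[int] = None       # padding inserted into the KV cache
```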
@@ -201,7 +190,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

      def _update_model_kwargs_for_generation(
          self,
-         outputs: RBLNDecoderOnlyForCausalLMOutput,
+         outputs: RBLNDecoderOnlyOutput,
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
@@ -298,7 +287,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          **lm_kwargs: Dict[str, Any],
-     ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
+     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
          # prefill
          if cache_position is None:
              logits = []
@@ -339,213 +328,11 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
              position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
          ).logits

-         return RBLNDecoderOnlyForCausalLMOutput(
+         return RBLNDecoderOnlyOutput(
              logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
          )


- class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-     def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.image_prefill = image_prefill  # FIXME(taehoon)
-         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-         self.decode = self.runtime if self.phase == "decode" else None
-
-     def _prepare_prefill_inputs(self, *args, **kwargs):
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-         # chunked_attention_mask shape
-         chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-         # In case of Gemma3ForConditionalGeneration, the loop counter may not be a prefill_chunk_size,
-         # so we cannot guarantee that the last chunk starts at a position that is a multiple of prefill_chunk_size.
-         if self.rbln_config.use_image_prefill:
-             padding_size = self.rbln_config.image_prefill_chunk_size
-             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-             cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-             position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-             token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-         return (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         )
-
-     def prefill_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: int = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         """
-         Performs chunked prefill for efficient KV-cache updates and memory optimization.
-         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-         """
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = self._prepare_prefill_inputs(
-             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-         )
-
-         step = 0
-         while step < query_length:
-             if self.rbln_config.use_image_prefill:
-                 # Check if the prefill chunk is an image prefill
-                 is_image_prefill = torch.all(
-                     token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-                 )
-                 # Check if the prefill chunk is a text prefill which have image_tokens in it.
-                 is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                     token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
-                 )
-             else:
-                 is_image_prefill, is_text_prefill_with_image_tokens = False, False
-
-             # Check if the prefill chunk is the last chunk
-             is_last_chunk = step + self.rbln_config.prefill_chunk_size >= query_length
-
-             input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-             cache_pos_chunk = (
-                 cache_position[:, step : step + self.rbln_config.prefill_chunk_size] + padded_cache_lengths
-             )
-             position_ids_chunk = position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-
-             # if text_prefill end with image_tokens, we only treat the text part.
-             num_processed_tokens = self.rbln_config.prefill_chunk_size
-             current_padded_cache_lengths = 0
-             if is_text_prefill_with_image_tokens:
-                 first_image_token_idx = torch.where(
-                     token_type_ids[:, step : step + self.rbln_config.prefill_chunk_size] == 1
-                 )[1][0]
-                 num_processed_tokens = first_image_token_idx.item()
-                 current_padded_cache_lengths = self.rbln_config.prefill_chunk_size - num_processed_tokens
-             if is_last_chunk:
-                 num_processed_tokens = query_length - step
-
-             chunked_attention_mask[
-                 :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-             ] = 1
-             query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-             if is_image_prefill:
-                 logits = self.image_prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-             else:
-                 logits = self.prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-
-             padded_cache_lengths += current_padded_cache_lengths
-             step += num_processed_tokens
-
-         if not is_external_block_tables:
-             self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-         return RBLNGemma3ForCausalLMOutput(
-             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-         )
-
-     def decode_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         batch_size = inputs.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-         if is_external_block_tables:
-             if attention_mask is None:
-                 raise ValueError("attention_mask should be provided with external block tables.")
-             if local_block_tables is None:
-                 raise ValueError("local_block_tables should be provided with external block tables.")
-         else:
-             local_block_tables = (
-                 local_block_tables
-                 if local_block_tables is not None
-                 else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-             )
-             if self.rbln_config.use_attention_mask and attention_mask is None:
-                 for b_idx in range(batch_size):
-                     decoding_step = cache_position[b_idx].item()
-                     if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                         raise ValueError(
-                             f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                         )
-                     self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                 attention_mask = self.dec_attn_mask
-
-         if self.batch_size < block_tables.shape[0]:
-             block_tables = block_tables[: self.batch_size]
-
-         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-             attention_mask = attention_mask[: self.batch_size]
-
-         logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-         return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
-
-
  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Gemma3 Model transformer with a language modeling head (linear layer) on top.
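The `prefill_forward` removed above (it now lives in `gemma3_runtime_utils.py`, per the import change earlier in this diff) documents the chunked-prefill scheme: instead of running the whole prompt through the compiled graph at once, the input is split into `prefill_chunk_size` pieces that fill the KV cache incrementally. Below is a minimal, self-contained sketch of that loop; `run_chunk` is a hypothetical stand-in for the compiled RBLN runtime call, not part of the library.

```python
# Sketch of chunked prefill, assuming run_chunk(chunk, cache_position) invokes a
# fixed-shape compiled graph and updates the KV cache as a side effect.
import torch


def chunked_prefill(inputs: torch.Tensor, chunk_size: int, run_chunk) -> torch.Tensor:
    # inputs: (batch=1, seq_len, hidden_size)
    seq_len = inputs.shape[1]
    # Compiled graphs require static shapes, so pad up to a multiple of chunk_size.
    padded_len = -(-seq_len // chunk_size) * chunk_size
    inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padded_len - seq_len))

    logits = None
    for step in range(0, seq_len, chunk_size):
        chunk = inputs[:, step : step + chunk_size]
        cache_position = torch.arange(step, step + chunk_size).unsqueeze(0)
        logits = run_chunk(chunk, cache_position)
    return logits  # logits from the last chunk drive the first decode step
```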
@@ -559,52 +346,34 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):

      _decoder_wrapper_cls = Gemma3ForCausalLMWrapper

-     def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
-         if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
-             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-             self.embed_tokens = self._create_embedding_layer()
-             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-         else:
-             self.embed_tokens = None
-
+     def setup_runtime(self):
          # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
          dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }

          self.prefill_decoder = RBLNGemma3RuntimeModel(
              runtime=self.model[0],
              image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
              phase="prefill",
              batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             vocab_size=self.config.vocab_size,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
+             **common_kwargs,
          )

          self.decoders = {}
          for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
              self.decoders[batch_size] = RBLNGemma3RuntimeModel(
                  runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
                  phase="decode",
                  batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
+                 **common_kwargs,
              )

          # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
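The new `setup_runtime` delegates KV-cache block bookkeeping to `RBLNPageTableManager` (introduced in `decoderonly_runtime_utils.py`, +450 lines). The deleted lines show what it subsumes: a per-sequence block table initialized to -1 and a `deque` of free blocks shared by the prefill and decode runtimes. A simplified allocator along those lines follows; the class and method names are assumptions, not the manager's real API.

```python
# Only the -1-initialized block table and the deque-based free pool are taken
# from the removed code; everything else here is an illustrative assumption.
from collections import deque

import torch


class NaivePageTableManager:
    def __init__(self, batch_size: int, max_seq_len: int, block_size: int, num_blocks: int):
        # One row of physical-block indices per sequence; -1 means "unmapped".
        self.block_tables = torch.full((batch_size, max_seq_len // block_size), -1, dtype=torch.int16)
        self.free_block_pool = deque(range(num_blocks))

    def get_block(self, batch_idx: int, logical_block: int) -> int:
        # Lazily map a sequence's logical block to a free physical KV-cache block.
        if self.block_tables[batch_idx, logical_block] == -1:
            if not self.free_block_pool:
                raise RuntimeError("KV cache exhausted: no free blocks left.")
            self.block_tables[batch_idx, logical_block] = self.free_block_pool.popleft()
        return int(self.block_tables[batch_idx, logical_block])
```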
@@ -634,7 +403,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          return rbln_config

      @classmethod
-     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+     def _update_submodule_config(
+         cls,
+         model: "PreTrainedModel",
+         rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
          if rbln_config.image_prefill_chunk_size is None:
              rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image

optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -47,6 +47,8 @@ class RBLNGPT2Model(RBLNDecoderOnlyModel):

      A class to convert and run pre-trained transformers based GPT2Model model on RBLN devices.
      It implements the methods to convert a pre-trained transformers GPT2Model model into a RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
      """

      _decoder_wrapper_cls = GPT2Wrapper
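For reference, the convert-and-compile flow this docstring describes is driven through optimum's usual `from_pretrained(..., export=True)` entry point. A hedged example; the `rbln_max_seq_len` option shown is a typical compile-time knob, not something taken from this diff:

```python
from optimum.rbln import RBLNGPT2Model

# export=True transfers the GPT2 checkpoint into an RBLN graph and compiles it.
model = RBLNGPT2Model.from_pretrained(
    "gpt2",
    export=True,
    rbln_max_seq_len=1024,  # assumed typical option; check your SDK version
)
model.save_pretrained("gpt2-rbln")  # reload later without recompiling
```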
optimum/rbln/transformers/models/grounding_dino/__init__.py
@@ -0,0 +1,10 @@
+ from .configuration_grounding_dino import (
+     RBLNGroundingDinoDecoderConfig,
+     RBLNGroundingDinoEncoderConfig,
+     RBLNGroundingDinoForObjectDetectionConfig,
+ )
+ from .modeling_grounding_dino import (
+     RBLNGroundingDinoDecoder,
+     RBLNGroundingDinoEncoder,
+     RBLNGroundingDinoForObjectDetection,
+ )
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py
@@ -0,0 +1,86 @@
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, List, Optional, Tuple, Union
+
+ import torch
+
+ from ...configuration_generic import RBLNImageModelConfig, RBLNModelConfig
+
+
+ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
+     submodules = [
+         "text_backbone",
+         "backbone",
+         "encoder",
+         "decoder",
+     ]
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         encoder: Optional["RBLNGroundingDinoEncoderConfig"] = None,
+         decoder: Optional["RBLNGroundingDinoDecoderConfig"] = None,
+         text_backbone: Optional["RBLNModelConfig"] = None,
+         backbone: Optional["RBLNModelConfig"] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.encoder = encoder
+         self.decoder = decoder
+         self.text_backbone = text_backbone
+         self.backbone = backbone
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+ class RBLNGroundingDinoComponentConfig(RBLNImageModelConfig):
+     def __init__(
+         self,
+         image_size: Optional[Union[int, Tuple[int, int]]] = None,
+         batch_size: Optional[int] = None,
+         spatial_shapes_list: Optional[List[Tuple[int, int]]] = None,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         **kwargs: Any,
+     ):
+         super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
+         self.spatial_shapes_list = spatial_shapes_list
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+
+     @property
+     def spatial_shapes(self):
+         if self.spatial_shapes_list is None:
+             raise ValueError("Spatial shapes are not defined. Please set them before accessing.")
+         return torch.tensor(self.spatial_shapes_list)
+
+
+ class RBLNGroundingDinoEncoderConfig(RBLNGroundingDinoComponentConfig):
+     pass
+
+
+ class RBLNGroundingDinoDecoderConfig(RBLNGroundingDinoComponentConfig):
+     pass
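To see how these pieces compose: `RBLNGroundingDinoForObjectDetectionConfig` declares four submodules, each configurable independently. Assuming the classes are re-exported at the package top level (the expanded `optimum/rbln/__init__.py` and `transformers/__init__.py` in this release suggest they are), construction might look like the following; all values are placeholders, not recommended settings.

```python
from optimum.rbln import (
    RBLNGroundingDinoDecoderConfig,
    RBLNGroundingDinoEncoderConfig,
    RBLNGroundingDinoForObjectDetectionConfig,
)

config = RBLNGroundingDinoForObjectDetectionConfig(
    batch_size=1,
    encoder=RBLNGroundingDinoEncoderConfig(image_size=(800, 1333), batch_size=1),
    decoder=RBLNGroundingDinoDecoderConfig(batch_size=1),
)
# spatial_shapes_list is normally filled in during compilation; reading the
# `spatial_shapes` property before it is set raises a ValueError.
```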