optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of optimum-rbln might be problematic.

Files changed (105)
  1. optimum/rbln/__init__.py +36 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +28 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +35 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +64 -258
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
+++ b/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -12,43 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
 from ..decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyForCausalLMOutput,
     RBLNDecoderOnlyModelForCausalLM,
-    RBLNRuntimeModel,
 )
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
+from .gemma3_runtime_utils import RBLNGemma3RuntimeModel
 
 
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
 
 
-@dataclass
-class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
-    attention_mask: Optional[torch.Tensor] = None
-
-
 class LoopVisionTower:
     def __init__(self, vision_tower: RBLNModel) -> None:
         self.vision_tower = vision_tower
@@ -201,7 +190,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
     def _update_model_kwargs_for_generation(
         self,
-        outputs: RBLNDecoderOnlyForCausalLMOutput,
+        outputs: RBLNDecoderOnlyOutput,
        model_kwargs: Dict[str, Any],
        **kwargs,
    ) -> Dict[str, Any]:
@@ -258,19 +247,47 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return inputs_embeds
 
+    def get_padded_cache_position(
+        self,
+        cache_position: torch.Tensor,  # shape: [1, seq_len]
+        token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+    ) -> torch.Tensor:
+        seq_len = cache_position[0][-1].item() + 1
+
+        # Find image start positions
+        image_starts = [
+            s
+            for s in torch.where(token_type_ids == 1)[1]
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+        ]
+
+        # Initialize padded tensors
+        padded_input_len = seq_len
+        for image_start in image_starts:
+            pad_needed = (
+                self.rbln_config.image_prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+            ) % self.rbln_config.image_prefill_chunk_size
+            padded_input_len += pad_needed
+
+        return torch.cat(
+            [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+            dim=1,
+        )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        token_type_ids: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
-    ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
+    ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
             logits = []
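
The new get_padded_cache_position helper aligns every image span to an image_prefill_chunk_size boundary by appending pad positions for the text that precedes it. Below is a minimal standalone sketch of the same arithmetic, assuming an illustrative chunk size of 256; the name pad_for_image_chunks is ours, not part of optimum-rbln.

import torch

# Illustrative re-implementation of the padding arithmetic above, not the package API.
def pad_for_image_chunks(cache_position: torch.Tensor, image_starts, chunk: int) -> torch.Tensor:
    seq_len = cache_position[0][-1].item() + 1
    padded_len = seq_len
    for start in image_starts:
        shifted = start + padded_len - seq_len  # where this image lands after earlier padding
        padded_len += (chunk - shifted % chunk) % chunk
    extra = torch.arange(seq_len, padded_len, dtype=torch.int32).unsqueeze(0)
    return torch.cat([cache_position, extra], dim=1)

# A 300-token prompt whose image span starts at position 40: 216 pad positions
# are appended so that the image begins on a 256-token boundary (300 + 216 = 516).
cache_position = torch.arange(300, dtype=torch.int32).unsqueeze(0)
print(pad_for_image_chunks(cache_position, [40], 256).shape)  # torch.Size([1, 516])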
@@ -279,12 +296,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                 output = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
-                    token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -308,217 +328,11 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
             ).logits
 
-        return RBLNDecoderOnlyForCausalLMOutput(
+        return RBLNDecoderOnlyOutput(
             logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
         )
 
 
-class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
-
-    def _prepare_prefill_inputs(self, *args, **kwargs):
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-        # chunked_attention_mask shape
-        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-        # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-        if self.rbln_config.use_image_prefill:
-            padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-            inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-            position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        step = 0
-        while step < query_length:
-            # Check if the prefill chunk is an image prefill
-            is_image_prefill = self.rbln_config.use_image_prefill and torch.all(
-                token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-            )
-            prefill_chunk_size = (
-                self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
-            )
-
-            # Check if the prefill chunk is a text prefill which have image_tokens in it.
-            is_text_prefill_with_image_tokens = (
-                self.rbln_config.use_image_prefill
-                and not is_image_prefill
-                and torch.any(token_type_ids[:, step : step + prefill_chunk_size] == 1)
-            )
-
-            # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
-            is_cross_block_boundary = (
-                step // self.rbln_config.kvcache_block_size
-                != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
-            )
-
-            # Check if the prefill chunk is the last chunk
-            is_last_chunk = step + prefill_chunk_size >= query_length
-
-            if is_cross_block_boundary:
-                padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
-                padded_cache_lengths += padding_size
-
-            # if text_prefill end with image_tokens, we only treat the text part.
-            num_processed_tokens = prefill_chunk_size
-            if is_text_prefill_with_image_tokens:
-                first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                num_processed_tokens = first_image_token_idx.item()
-            if is_last_chunk:
-                num_processed_tokens = query_length - step
-
-            input_chunk = inputs[:, step : step + prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
-            position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
-            chunked_attention_mask[
-                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-            ] = 1
-            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-            if is_image_prefill:
-                logits = self.image_prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-            else:
-                logits = self.prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-
-            step += num_processed_tokens
-
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-        return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
-
-
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
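
The docstring of the removed prefill_forward (the class now lives in gemma3_runtime_utils.py, per the file list above) describes the chunked-prefill pattern: the prompt is fed to the compiled runtime one fixed-size chunk at a time, so the KV cache fills incrementally instead of requiring a full-length graph. A rough sketch of that pattern only; run_chunk stands in for one compiled-runtime call, and the real loop additionally handles image chunks, block tables, and attention masks.

import torch

def chunked_prefill(inputs_embeds: torch.Tensor, run_chunk, prefill_chunk_size: int = 128):
    # Feed a long prompt one chunk at a time; each call updates the KV cache
    # as a side effect, and only the last chunk's logits matter for generation.
    query_length = inputs_embeds.shape[1]
    logits = None
    for step in range(0, query_length, prefill_chunk_size):
        chunk = inputs_embeds[:, step : step + prefill_chunk_size]
        cache_position = torch.arange(step, step + chunk.shape[1], dtype=torch.int32).unsqueeze(0)
        logits = run_chunk(chunk, cache_position)
    return logits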
@@ -532,52 +346,34 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
 
-    def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
-        if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
-            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
+    def setup_runtime(self):
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
 
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            vocab_size=self.config.vocab_size,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
+            **common_kwargs,
         )
 
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
                 runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
                 phase="decode",
                 batch_size=batch_size,
-                dec_attn_mask=dec_attn_mask,
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+                **common_kwargs,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
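
The setup_runtime refactor passes one dec_attn_mask tensor and one RBLNPageTableManager to the prefill runtime and every decode runtime, replacing the hand-wired block table and free-block pool of the old __post_init__. A toy sketch of why sharing the same objects matters; the classes below are stand-ins, not the optimum-rbln implementations.

import torch

# Both phases hold references to the *same* state, so cache positions marked
# during prefill are immediately visible to every decoder without copying.
class SharedState:
    def __init__(self, batch_size: int, max_seq_len: int):
        self.dec_attn_mask = torch.zeros(batch_size, max_seq_len)

class Phase:
    def __init__(self, shared: SharedState):
        self.shared = shared  # reference, not a copy

shared = SharedState(batch_size=1, max_seq_len=8)
prefill, decode = Phase(shared), Phase(shared)
prefill.shared.dec_attn_mask[0, :4] = 1          # prefill marks 4 cached positions
assert decode.shared.dec_attn_mask[0, :4].all()  # decode sees them immediately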
@@ -607,7 +403,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return rbln_config
 
     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.image_prefill_chunk_size is None:
             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
@@ -633,6 +434,11 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
 
         if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+
             # Update image prefill compile config
             img_prefill_input_info = cls.get_input_info(
                 batch_size=1,
--- a/optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
+++ b/optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -47,6 +47,8 @@ class RBLNGPT2Model(RBLNDecoderOnlyModel):
 
     A class to convert and run pre-trained transformers based GPT2Model model on RBLN devices.
     It implements the methods to convert a pre-trained transformers GPT2Model model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
     """
 
     _decoder_wrapper_cls = GPT2Wrapper
--- /dev/null
+++ b/optimum/rbln/transformers/models/grounding_dino/__init__.py
@@ -0,0 +1,10 @@
+from .configuration_grounding_dino import (
+    RBLNGroundingDinoDecoderConfig,
+    RBLNGroundingDinoEncoderConfig,
+    RBLNGroundingDinoForObjectDetectionConfig,
+)
+from .modeling_grounding_dino import (
+    RBLNGroundingDinoDecoder,
+    RBLNGroundingDinoEncoder,
+    RBLNGroundingDinoForObjectDetection,
+)
--- /dev/null
+++ b/optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py
@@ -0,0 +1,86 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+
+from ...configuration_generic import RBLNImageModelConfig, RBLNModelConfig
+
+
+class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
+    submodules = [
+        "text_backbone",
+        "backbone",
+        "encoder",
+        "decoder",
+    ]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        encoder: Optional["RBLNGroundingDinoEncoderConfig"] = None,
+        decoder: Optional["RBLNGroundingDinoDecoderConfig"] = None,
+        text_backbone: Optional["RBLNModelConfig"] = None,
+        backbone: Optional["RBLNModelConfig"] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.encoder = encoder
+        self.decoder = decoder
+        self.text_backbone = text_backbone
+        self.backbone = backbone
+        self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+class RBLNGroundingDinoComponentConfig(RBLNImageModelConfig):
+    def __init__(
+        self,
+        image_size: Optional[Union[int, Tuple[int, int]]] = None,
+        batch_size: Optional[int] = None,
+        spatial_shapes_list: Optional[List[Tuple[int, int]]] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs: Any,
+    ):
+        super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
+        self.spatial_shapes_list = spatial_shapes_list
+        self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
+
+    @property
+    def spatial_shapes(self):
+        if self.spatial_shapes_list is None:
+            raise ValueError("Spatial shapes are not defined. Please set them before accessing.")
+        return torch.tensor(self.spatial_shapes_list)
+
+
+class RBLNGroundingDinoEncoderConfig(RBLNGroundingDinoComponentConfig):
+    pass
+
+
+class RBLNGroundingDinoDecoderConfig(RBLNGroundingDinoComponentConfig):
+    pass
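
For context, a component config from this new file might be built as below; a hedged sketch, assuming the import path exported by the new __init__.py above, with illustrative feature-map sizes.

import torch

from optimum.rbln.transformers.models.grounding_dino import RBLNGroundingDinoEncoderConfig

# spatial_shapes_list pins the multi-scale feature-map sizes at compile time;
# the spatial_shapes property exposes them as the tensor the model code consumes.
encoder_cfg = RBLNGroundingDinoEncoderConfig(
    batch_size=1,
    spatial_shapes_list=[(100, 100), (50, 50), (25, 25), (13, 13)],  # illustrative sizes
)
assert encoder_cfg.spatial_shapes.shape == torch.Size([4, 2])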