optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (90)
  1. optimum/rbln/__init__.py +8 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +6 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +10 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -257
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  81. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  82. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  83. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  84. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  85. optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
  86. optimum/rbln/utils/runtime_utils.py +3 -3
  87. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
  88. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
  89. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
  90. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
@@ -12,43 +12,32 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  import inspect
- from collections import deque
- from dataclasses import dataclass
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     AutoModelForImageTextToText,
-     Gemma3ForConditionalGeneration,
-     PretrainedConfig,
-     PreTrainedModel,
- )
+ from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.modeling_utils import no_init_weights
  from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
  from ..decoderonly.modeling_decoderonly import (
-     RBLNDecoderOnlyForCausalLMOutput,
      RBLNDecoderOnlyModelForCausalLM,
-     RBLNRuntimeModel,
  )
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper
+ from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


  if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


- @dataclass
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyForCausalLMOutput):
-     attention_mask: Optional[torch.Tensor] = None
-
-
  class LoopVisionTower:
      def __init__(self, vision_tower: RBLNModel) -> None:
          self.vision_tower = vision_tower
@@ -201,7 +190,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

      def _update_model_kwargs_for_generation(
          self,
-         outputs: RBLNDecoderOnlyForCausalLMOutput,
+         outputs: RBLNDecoderOnlyOutput,
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
@@ -258,19 +247,47 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

          return inputs_embeds

+     def get_padded_cache_position(
+         self,
+         cache_position: torch.Tensor,  # shape: [1, seq_len]
+         token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+     ) -> torch.Tensor:
+         seq_len = cache_position[0][-1].item() + 1
+
+         # Find image start positions
+         image_starts = [
+             s
+             for s in torch.where(token_type_ids == 1)[1]
+             if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+         ]
+
+         # Initialize padded tensors
+         padded_input_len = seq_len
+         for image_start in image_starts:
+             pad_needed = (
+                 self.rbln_config.image_prefill_chunk_size
+                 - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+             ) % self.rbln_config.image_prefill_chunk_size
+             padded_input_len += pad_needed
+
+         return torch.cat(
+             [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+             dim=1,
+         )
+
      def forward(
          self,
          input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
          pixel_values: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
          cache_position: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          generate_idx: Optional[torch.Tensor] = None,
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
          **lm_kwargs: Dict[str, Any],
-     ) -> Union[Tuple, RBLNDecoderOnlyForCausalLMOutput]:
+     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
          # prefill
          if cache_position is None:
              logits = []
@@ -279,12 +296,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

              for b_idx in range(batch_size):
                  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                  output = self.language_model.prefill_decoder(
                      inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                      attention_mask=attention_mask[b_idx],
                      cache_position=cache_position,
                      batch_idx=b_idx,
-                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                  )
                  padded_cache_lengths[b_idx] += output.padded_cache_lengths
                  logits.append(output.logits)
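The padding arithmetic in the new `get_padded_cache_position` method is easiest to see with concrete numbers. The snippet below is a standalone sketch of the same alignment logic; the chunk size of 4 and the 10-token prompt with an image span starting at position 3 are made-up values, whereas the real method reads them from `rbln_config` and the processor's `token_type_ids`.

```python
import torch

# Assumed values for illustration only.
image_prefill_chunk_size = 4
token_type_ids = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 0, 0, 0]])  # 1 marks image tokens
cache_position = torch.arange(0, token_type_ids.shape[1], dtype=torch.int32).unsqueeze(0)

seq_len = cache_position[0][-1].item() + 1  # 10

# Image spans that fill a whole image prefill chunk (here: the span starting at index 3).
image_starts = [
    s
    for s in torch.where(token_type_ids == 1)[1]
    if torch.all(token_type_ids[:, s : s + image_prefill_chunk_size] == 1)
]

padded_input_len = seq_len
for image_start in image_starts:
    # Pad so the image span begins on an image_prefill_chunk_size boundary.
    offset = int(image_start) + padded_input_len - seq_len
    pad_needed = (image_prefill_chunk_size - offset % image_prefill_chunk_size) % image_prefill_chunk_size
    padded_input_len += pad_needed  # start 3 needs 1 pad token, so 10 -> 11

padded_cache_position = torch.cat(
    [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)], dim=1
)
print(padded_cache_position.shape)  # torch.Size([1, 11])
```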
@@ -308,217 +328,11 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
                  position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
              ).logits

-         return RBLNDecoderOnlyForCausalLMOutput(
+         return RBLNDecoderOnlyOutput(
              logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
          )


- class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-     def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.image_prefill = image_prefill  # FIXME(taehoon)
-         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-         self.decode = self.runtime if self.phase == "decode" else None
-
-     def _prepare_prefill_inputs(self, *args, **kwargs):
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-         # chunked_attention_mask shape
-         chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-         # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-         if self.rbln_config.use_image_prefill:
-             padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-             cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-             position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-             token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-         return (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         )
-
-     def prefill_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: int = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         """
-         Performs chunked prefill for efficient KV-cache updates and memory optimization.
-         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-         """
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = self._prepare_prefill_inputs(
-             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-         )
-
-         step = 0
-         while step < query_length:
-             # Check if the prefill chunk is an image prefill
-             is_image_prefill = self.rbln_config.use_image_prefill and torch.all(
-                 token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-             )
-             prefill_chunk_size = (
-                 self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
-             )
-
-             # Check if the prefill chunk is a text prefill which have image_tokens in it.
-             is_text_prefill_with_image_tokens = (
-                 self.rbln_config.use_image_prefill
-                 and not is_image_prefill
-                 and torch.any(token_type_ids[:, step : step + prefill_chunk_size] == 1)
-             )
-
-             # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
-             is_cross_block_boundary = (
-                 step // self.rbln_config.kvcache_block_size
-                 != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
-             )
-
-             # Check if the prefill chunk is the last chunk
-             is_last_chunk = step + prefill_chunk_size >= query_length
-
-             if is_cross_block_boundary:
-                 padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
-                 padded_cache_lengths += padding_size
-
-             # if text_prefill end with image_tokens, we only treat the text part.
-             num_processed_tokens = prefill_chunk_size
-             if is_text_prefill_with_image_tokens:
-                 first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                 num_processed_tokens = first_image_token_idx.item()
-             if is_last_chunk:
-                 num_processed_tokens = query_length - step
-
-             input_chunk = inputs[:, step : step + prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
-             position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
-             chunked_attention_mask[
-                 :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-             ] = 1
-             query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-             if is_image_prefill:
-                 logits = self.image_prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-             else:
-                 logits = self.prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-
-             step += num_processed_tokens
-
-         if not is_external_block_tables:
-             self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-         return RBLNGemma3ForCausalLMOutput(
-             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-         )
-
-     def decode_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         batch_size = inputs.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-         if is_external_block_tables:
-             if attention_mask is None:
-                 raise ValueError("attention_mask should be provided with external block tables.")
-             if local_block_tables is None:
-                 raise ValueError("local_block_tables should be provided with external block tables.")
-         else:
-             local_block_tables = (
-                 local_block_tables
-                 if local_block_tables is not None
-                 else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-             )
-             if self.rbln_config.use_attention_mask and attention_mask is None:
-                 for b_idx in range(batch_size):
-                     decoding_step = cache_position[b_idx].item()
-                     if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                         raise ValueError(
-                             f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                         )
-                     self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                 attention_mask = self.dec_attn_mask
-
-         if self.batch_size < block_tables.shape[0]:
-             block_tables = block_tables[: self.batch_size]
-
-         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-             attention_mask = attention_mask[: self.batch_size]
-
-         logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-         return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
-
-
  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -532,52 +346,34 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):

      _decoder_wrapper_cls = Gemma3ForCausalLMWrapper

-     def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
-         if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
-             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-             self.embed_tokens = self._create_embedding_layer()
-             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-         else:
-             self.embed_tokens = None
-
+     def setup_runtime(self):
          # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
          dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }

          self.prefill_decoder = RBLNGemma3RuntimeModel(
              runtime=self.model[0],
              image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
              phase="prefill",
              batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             vocab_size=self.config.vocab_size,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
+             **common_kwargs,
          )

          self.decoders = {}
          for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
              self.decoders[batch_size] = RBLNGemma3RuntimeModel(
                  runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
                  phase="decode",
                  batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
+                 **common_kwargs,
              )

          # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
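For context on the `RBLNPageTableManager` that replaces the inline block-table state here: its implementation lives in the newly added `decoderonly_runtime_utils.py` and is not shown in this hunk. The sketch below only restates the shared state the old `__post_init__` built by hand (a per-batch block table initialized to -1 plus a deque of free KV-cache block indices); the class and method names are hypothetical, not the library's API.

```python
from collections import deque

import torch


class PageTableStateSketch:  # hypothetical name; not the real RBLNPageTableManager
    """Illustrates the KV-cache paging state the old __post_init__ created inline."""

    def __init__(self, batch_size: int, max_seq_len: int, kvcache_block_size: int, kvcache_num_blocks: int):
        # One row of block indices per batch entry; -1 means "no block assigned yet".
        self.block_tables = torch.full(
            (batch_size, max_seq_len // kvcache_block_size), -1, dtype=torch.int16
        )
        # Pool of block indices still available to hand out to sequences.
        self.free_block_pool = deque(range(kvcache_num_blocks))

    def assign_block(self, batch_idx: int, slot: int) -> int:
        # Take the next free block and record it in the batch entry's table row.
        block = self.free_block_pool.popleft()
        self.block_tables[batch_idx, slot] = block
        return block
```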
@@ -633,6 +429,11 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")

          if rbln_config.use_image_prefill:
+             if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                 raise NotImplementedError(
+                     "Not implemented for different prefill chunk sizes between text and image prefill."
+                 )
+
              # Update image prefill compile config
              img_prefill_input_info = cls.get_input_info(
                  batch_size=1,
@@ -47,6 +47,8 @@ class RBLNGPT2Model(RBLNDecoderOnlyModel):

      A class to convert and run pre-trained transformers based GPT2Model model on RBLN devices.
      It implements the methods to convert a pre-trained transformers GPT2Model model into a RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
      """

      _decoder_wrapper_cls = GPT2Wrapper
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig

@@ -39,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
          batch_size: Optional[int] = None,
          vision_model: Optional[RBLNModelConfig] = None,
          text_model: Optional[RBLNModelConfig] = None,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ):
          """
          Args:
@@ -34,17 +34,11 @@ from transformers.models.idefics3.modeling_idefics3 import Idefics3CausalLMOutputWithPast
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ..decoderonly.modeling_decoderonly import (
-     RBLNDecoderOnlyForCausalLMOutput,
- )
+ from ...modeling_outputs import RBLNDecoderOnlyOutput


  if TYPE_CHECKING:
-     from transformers import (
-         AutoFeatureExtractor,
-         AutoProcessor,
-         AutoTokenizer,
-     )
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


  class RBLNRuntimeVisionModel(RBLNPytorchRuntime):
@@ -494,7 +488,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
          if not return_dict:
              return logits, generate_idx
          else:
-             return RBLNDecoderOnlyForCausalLMOutput(
+             return RBLNDecoderOnlyOutput(
                  logits=logits,
                  generate_idx=generate_idx,
              )
@@ -85,11 +85,20 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):

  class RBLNLlamaModel(RBLNDecoderOnlyModel):
      """
-     The Llama Model transformer with a language modeling head (linear layer) on top.
+     The Llama Model transformer outputting raw hidden-states without any specific head on top.
      This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

-     A class to convert and run pre-trained transformers based LlamaModel model on RBLN devices.
-     It implements the methods to convert a pre-trained transformers LlamaModel model into a RBLN transformer model by:
+     A class to convert and run pre-trained transformers based LlamaModel on RBLN devices.
+     It implements the methods to convert a pre-trained transformers LlamaModel into a RBLN transformer model by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNLlamaModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNLlamaModelConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNLlamaModelConfig`] class for all available configuration options.
      """

      _decoder_wrapper_cls = LlamaWrapper
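Going by the updated docstring above, a minimal export sketch for `RBLNLlamaModel` could look like the following; the checkpoint id and the config values are placeholders, and `rbln_config` may equally be passed as an `RBLNLlamaModelConfig` instance rather than a dictionary.

```python
from optimum.rbln import RBLNLlamaModel

# Checkpoint id and config values are illustrative, not taken from this diff.
model = RBLNLlamaModel.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    export=True,  # compile the Hugging Face checkpoint into an RBLN graph
    rbln_config={"tensor_parallel_size": 4},  # or an RBLNLlamaModelConfig instance
)
model.save_pretrained("compiled-llama-model")
```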
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig

@@ -33,7 +33,7 @@ class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
          batch_size: Optional[int] = None,
          vision_tower: Optional[RBLNModelConfig] = None,
          language_model: Optional[RBLNModelConfig] = None,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ):
          """
          Args:
@@ -16,30 +16,20 @@ import inspect
  from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

  import torch
- from transformers import (
-     AutoModelForImageTextToText,
-     LlavaForConditionalGeneration,
-     PretrainedConfig,
-     PreTrainedModel,
- )
+ from transformers import AutoModelForImageTextToText, LlavaForConditionalGeneration, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyForCausalLMOutput
+ from ...modeling_outputs import RBLNDecoderOnlyOutput


  logger = get_logger(__name__)

  if TYPE_CHECKING:
-     from transformers import (
-         AutoFeatureExtractor,
-         AutoProcessor,
-         AutoTokenizer,
-         PretrainedConfig,
-     )
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


  class LoopVisionTower:
@@ -111,6 +101,55 @@ class LoopProjector:


  class RBLNLlavaForConditionalGeneration(RBLNModel):
+     """
+     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
+     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     Important Note:
+         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+         `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNLlavaForConditionalGeneration class for details.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNLlavaForConditionalGeneration
+
+         model = RBLNLlavaForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-1.5-7b-hf",
+             export=True,
+             rbln_config={
+                 "vision_tower": {"output_hidden_states": True},
+                 "language_model": {
+                     "tensor_parallel_size": 4,
+                     "use_inputs_embeds": True,  # In Llava, language model must use inputs_embeds as input.
+                 },
+             },
+         )
+         model.save_pretrained("compiled-llava-1.5-7b-hf")
+
+         # Using a RBLNLlavaForConditionalGenerationConfig instance (recommended for type checking)
+         from optimum.rbln import RBLNLlavaForConditionalGenerationConfig
+
+         vision_config = RBLNCLIPVisionModelConfig(
+             batch_size=1,
+             output_hidden_states=True
+         )
+         language_model_config = RBLNLlamaForCausalLMConfig(
+             batch_size=1,
+             max_seq_len=4096,
+             use_inputs_embeds=True,
+             tensor_parallel_size=4
+         )
+         llava_config = RBLNLlavaForConditionalGenerationConfig(
+             batch_size=1,
+             vision_tower=vision_config,
+             language_model=language_model_config
+         )
+         model = RBLNLlavaForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-1.5-7b-hf",
+             export=True,
+             rbln_config=llava_config
+         )
+         ```
+     """
+
      auto_model_class = AutoModelForImageTextToText
      _rbln_submodules = [
          {"name": "vision_tower"},
@@ -374,7 +413,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel):
          if not return_dict:
              return logits, generate_idx
          else:
-             return RBLNDecoderOnlyForCausalLMOutput(
+             return RBLNDecoderOnlyOutput(
                  logits=logits,
                  generate_idx=generate_idx,
              )
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
@@ -38,7 +38,7 @@ class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
          batch_size: Optional[int] = None,
          vision_tower: Optional[RBLNModelConfig] = None,
          language_model: Optional[RBLNModelConfig] = None,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ):
          """
          Args: