optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc1__py3-none-any.whl

This diff shows the changes between publicly released package versions, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (31)
  1. optimum/rbln/__init__.py +10 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -2
  4. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -2
  5. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -2
  6. optimum/rbln/modeling.py +53 -9
  7. optimum/rbln/modeling_base.py +22 -3
  8. optimum/rbln/transformers/__init__.py +10 -0
  9. optimum/rbln/transformers/modeling_generic.py +0 -19
  10. optimum/rbln/transformers/models/__init__.py +14 -0
  11. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  12. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  13. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
  14. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
  15. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  16. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  17. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  18. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
  19. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
  20. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
  21. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  22. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  23. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
  24. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
  25. optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
  26. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
  27. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -3
  28. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/METADATA +1 -1
  29. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/RECORD +31 -27
  30. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/WHEEL +0 -0
  31. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc1.dist-info}/licenses/LICENSE +0 -0
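
The headline change is Gemma3 support: two new model classes (`RBLNGemma3ForCausalLM`, `RBLNGemma3ForConditionalGeneration`) with matching configs, plus `token_type_ids`, `local_block_tables`, and `use_position_ids` plumbing through the decoder-only runtime, shown in the hunks below. A minimal sketch of driving the new causal-LM class, assuming the usual optimum-rbln `from_pretrained(export=True)` flow; the checkpoint ID and `rbln_*` values are illustrative, not taken from this diff:

from transformers import AutoTokenizer
from optimum.rbln import RBLNGemma3ForCausalLM

# Sketch only: checkpoint ID and rbln_* values are illustrative assumptions.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
model = RBLNGemma3ForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    export=True,        # compile the PyTorch checkpoint for RBLN NPUs
    rbln_batch_size=1,  # forwarded into RBLNGemma3ForCausalLMConfig
)
inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))
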
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -167,6 +167,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -193,6 +195,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 attention_mask=attention_mask,
                 position_embed=position_embed,
                 position_ids=position_ids,
+                local_block_tables=local_block_tables,
             )
         else:
             return self.prefill_forward(
@@ -202,6 +205,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 batch_idx,
                 block_tables,
                 position_embed=position_embed,
+                token_type_ids=token_type_ids,
+                local_block_tables=local_block_tables,
             )

     def decode_forward(
@@ -213,6 +218,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -262,6 +268,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         cache_position: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ):
         """
         Prepare inputs for prefill phase.
@@ -345,6 +353,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         position_embed: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -360,7 +370,9 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed,
             padded_cache_lengths,
             query_length,
-        ) = self._prepare_prefill_inputs(inputs, cache_position, attention_mask, position_embed)
+        ) = self._prepare_prefill_inputs(
+            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+        )

         # Process input in chunks of size `prefill_chunk_size`
         for step in range(0, query_length, self.prefill_chunk_size):
@@ -373,7 +385,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             if position_embed is not None:
                 position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

-            if self.use_attention_mask:
+            if self.use_attention_mask and not self.use_position_ids:
                 # Update attention mask to ensure proper causal behavior
                 if step >= self.prefill_chunk_size:
                     chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
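
An aside on the condition change just above: when `use_position_ids` is set, mask maintenance moves out of this loop, but the chunked-prefill masking pattern itself is unchanged. A standalone sketch of that pattern with simplified shapes and names (not the library's internals):

import torch

# Standalone sketch of chunked-prefill causal masking; shapes and names
# are simplified assumptions, not the library's internals.
prefill_chunk_size = 4
query_length = 8  # assume already padded to a multiple of the chunk size
causal_chunk = torch.tril(torch.ones(prefill_chunk_size, prefill_chunk_size))
mask = torch.zeros(1, 1, prefill_chunk_size, query_length)

for step in range(0, query_length, prefill_chunk_size):
    if step >= prefill_chunk_size:
        # keys from all previously processed chunks are fully visible
        mask[:, :, :, step - prefill_chunk_size : step] = 1
    # within the current chunk, visibility stays lower-triangular (causal)
    mask[:, :, :, step : step + prefill_chunk_size] = causal_chunk
    # ... run one compiled prefill step on this chunk with `mask` ...

Each chunk attends to all earlier chunks in full and to itself causally, which is exactly what the `chunked_attention_mask` updates above maintain.
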
@@ -387,10 +399,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 query_position,
                 chunked_attention_mask if self.use_attention_mask else None,
-                position_ids_chunk if position_ids is not None else None,
-                position_embed_chunk if position_embed is not None else None,
+                position_ids_chunk if self.use_position_ids else None,
                 out=out_buffers,
             )

@@ -440,12 +452,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            with no_init_weights():
-                self.embed_tokens = torch.nn.Embedding(
-                    self.config.vocab_size,
-                    self.config.hidden_size,
-                    self.config.pad_token_id,
-                )
+            self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
@@ -478,6 +485,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             attn_impl=self.rbln_config.attn_impl,
             use_position_ids=self.rbln_config.use_position_ids,
         )
+
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNRuntimeModel(
@@ -515,6 +523,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
     def get_input_embeddings(self):
         return self.embed_tokens

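The `_create_embedding_layer` hook introduced here (and called in the earlier hunk) factors the `no_init_weights()` construction into one overridable place for subclasses such as the new Gemma3 models. A minimal illustration of the skip-init trick with made-up sizes; in recent `transformers` releases the context manager turns `torch.nn.init` calls into no-ops, so the embedding is allocated but not randomly initialized before `load_state_dict` overwrites it:

import torch
from transformers.modeling_utils import no_init_weights

# Made-up sizes, for illustration only. Inside the context manager the
# usual random init is skipped, since real weights are loaded right after.
with no_init_weights():
    embed_tokens = torch.nn.Embedding(32000, 4096, padding_idx=0)

embed_tokens.load_state_dict({"weight": torch.zeros(32000, 4096)})
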
@@ -1101,6 +1118,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1123,6 +1141,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
+                token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
             )
             padded_cache_lengths[b_idx] += output.padded_cache_lengths
             logits.append(output.logits)
optimum/rbln/transformers/models/exaone/exaone_architecture.py
@@ -41,7 +41,10 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.transformer.h:
             if self.attn_impl == "eager":
                 new_self_attn = ExaoneAttention(
-                    layer.attn.attention, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.attn.attention,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = ExaoneFlashAttention(
@@ -49,6 +52,7 @@ class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -34,7 +34,10 @@ class GemmaWrapper(DecoderOnlyWrapper):
         for layer in causal_lm.model.layers:
             if self.attn_impl == "eager":
                 new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn, self.use_attention_mask, kvcache_block_size=self.kvcache_block_size
+                    layer.self_attn,
+                    self.use_attention_mask,
+                    kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             elif self.attn_impl == "flash_attn":
                 new_self_attn = DecoderOnlyFlashAttention(
@@ -42,6 +45,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
                     kvcache_partition_len=self.kvcache_partition_len,
                     use_attention_mask=self.use_attention_mask,
                     kvcache_block_size=self.kvcache_block_size,
+                    use_position_ids=self.use_position_ids,
                 )
             else:
                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
optimum/rbln/transformers/models/gemma3/__init__.py (new file)
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig, RBLNGemma3ForConditionalGenerationConfig
+from .modeling_gemma3 import RBLNGemma3ForCausalLM, RBLNGemma3ForConditionalGeneration
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py (new file)
@@ -0,0 +1,69 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
+
+
+class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    def __init__(
+        self,
+        prefill_chunk_size: Optional[int] = None,
+        use_position_ids: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
+        **kwargs,
+    ):
+        # use_attention_mask and use_position_ids are always True for Gemma3
+        use_attention_mask = use_attention_mask or True
+        use_position_ids = use_position_ids or True
+        prefill_chunk_size = prefill_chunk_size or 256
+
+        super().__init__(
+            prefill_chunk_size=prefill_chunk_size,
+            use_attention_mask=use_attention_mask,
+            use_position_ids=use_position_ids,
+            **kwargs,
+        )
+
+
+class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
+    submodules = ["vision_tower", "language_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_tower: Optional[RBLNModelConfig] = None,
+        language_model: Optional[RBLNModelConfig] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vision_tower = self.init_submodule_config(RBLNSiglipVisionModelConfig, vision_tower)
+        self.language_model = self.init_submodule_config(RBLNGemma3ForCausalLMConfig, language_model)
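
Taken together, the two new config classes wire a SigLIP vision tower to a Gemma3 decoder, with `use_attention_mask` and `use_position_ids` forced on and `prefill_chunk_size` defaulting to 256. A hedged sketch of composing them, assuming the names are re-exported at the top level (the +10-line `__init__.py` changes suggest so) and that `max_seq_len` is an inherited decoder-only kwarg; all values are illustrative:

from optimum.rbln import (
    RBLNGemma3ForCausalLMConfig,
    RBLNGemma3ForConditionalGenerationConfig,
)

# Illustrative values; `max_seq_len` is assumed to be an inherited
# RBLNDecoderOnlyModelForCausalLMConfig kwarg.
config = RBLNGemma3ForConditionalGenerationConfig(
    batch_size=1,
    language_model=RBLNGemma3ForCausalLMConfig(max_seq_len=8192),
)
print(config.language_model.prefill_chunk_size)  # expected: 256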