optimum-rbln 0.7.4a5__py3-none-any.whl → 0.7.4a7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -163,6 +163,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -187,9 +188,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 block_tables,
                 is_external_block_tables,
                 attention_mask=attention_mask,
+                position_embed=position_embed,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
+            return self.prefill_forward(
+                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+            )

     def decode_forward(
         self,
@@ -198,6 +202,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -229,6 +234,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position,
             attention_mask if self.use_attention_mask else None,
             block_tables,
+            position_embed,
         )

         return logits
@@ -241,6 +247,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -251,6 +258,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )

         query_length = inputs.shape[1]
         if query_length > self.max_seq_len:
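
The boolean-mask selection above drops padded positions from both the inputs and the new position_embed tensor before prefill. Below is a minimal sketch of that step, assuming position_embed is a stacked cos/sin rotary table of shape [2, batch, 1, seq, head_dim]; that layout is an inference from the 5-D indexing in the hunk, not something this diff states.

import torch

# Illustrative shapes only; the position_embed layout is an assumption.
batch, seq, hidden, head_dim = 1, 8, 16, 4
inputs = torch.randn(batch, seq, hidden)
position_embed = torch.randn(2, batch, 1, seq, head_dim)
attention_mask = torch.tensor([1, 1, 1, 1, 1, 0, 0, 0])

# Keep only the valid (non-masked) positions, mirroring the hunk.
valid_inputs = inputs[:, attention_mask.bool()]
valid_pos = position_embed[:, :, :, attention_mask.bool(), :]
print(valid_inputs.shape)  # torch.Size([1, 5, 16])
print(valid_pos.shape)     # torch.Size([2, 1, 1, 5, 4])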
@@ -295,9 +306,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 dim=-1,
             )

+        if position_embed is not None:
+            position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
         # Extract the current chunk of inputs and cache positions
         input_chunk = inputs[:, step : step + self.prefill_chunk_size]
         cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+        if position_embed is not None:
+            position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

         if self.use_attention_mask:
             # Update attention mask to ensure proper causal behavior
@@ -315,6 +331,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             chunked_attention_mask if self.use_attention_mask else None,
             query_position,
             block_tables,
+            position_embed_chunk if position_embed is not None else None,
             out=out_buffers,
         )

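Together, the two hunks above extend the existing chunked-prefill mechanics to the position embeddings: the sequence axis is right-padded to a multiple of the prefill chunk size, then fixed-size chunks are sliced off per step. A self-contained sketch with made-up sizes (the real prefill_chunk_size and tensor layout live in the runtime, and the 5-D layout is the same assumption as above):

import torch

prefill_chunk_size = 4
batch, seq, hidden, head_dim = 1, 10, 8, 4
inputs = torch.randn(batch, seq, hidden)
position_embed = torch.randn(2, batch, 1, seq, head_dim)  # assumed layout

# Right-pad the sequence so it divides evenly into chunks.
padding_size = -seq % prefill_chunk_size  # 2
inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
# For the 5-D tensor, (0, 0, 0, padding_size) pads the second-to-last
# (sequence) dimension, exactly as in the hunk above.
position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))

for step in range(0, seq, prefill_chunk_size):
    input_chunk = inputs[:, step : step + prefill_chunk_size]
    position_embed_chunk = position_embed[:, :, :, step : step + prefill_chunk_size, :]
    print(input_chunk.shape, position_embed_chunk.shape)
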
@@ -434,6 +451,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_input_embeddings(self):
         return self.embed_tokens

+    def get_attn_impl(self) -> str:
+        return self.rbln_config.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.kvcache_num_blocks
+
     @classmethod
     def get_quantized_model(
         cls,
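
The two new accessors simply expose fields of the model's rbln_config. A hypothetical usage sketch (the model id and the `export` keyword are placeholders for the usual optimum-style loading call):

# Placeholder model id; assumes a precompiled RBLN model on disk or the hub.
model = RBLNDecoderOnlyModelForCausalLM.from_pretrained("some/compiled-model", export=False)
print(model.get_attn_impl())           # e.g. "flash_attn"
print(model.get_kvcache_num_blocks())  # e.g. 8193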
@@ -565,7 +588,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         finally:
             torch.nn.functional.linear = original_linear

-        return compile_model(quantize_config=rbln_config.quantization)
+        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
+            )
+
+        return compiled_models
+
+    @classmethod
+    def maybe_suggest_kvcache_num_blocks(
+        cls,
+        compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        model_config: PretrainedConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> None:
+        # Get the actual memory allocation of each node by key
+        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+        alloc_memory_by_key: Dict[str, int] = {
+            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+        }
+        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+        # Get the maximum number of blocks that can be allocated
+        buffer = sum(alloc_memory_by_key.values())
+        max_num_blocks = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_config.tensor_parallel_size,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kernel_size=kernel_size,
+            buffer=buffer,
+        )
+
+        # Since our estimation logic is not always accurate,
+        # users can set `kvcache_num_blocks` to `max_num_blocks`.
+        # If the memory is not enough, the model will fail to compile.
+        if rbln_config.kvcache_num_blocks < max_num_blocks:
+            logger.warning(
+                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                "Our analysis indicates that additional memory is available for more blocks. "
+                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                "Please be advised that our memory estimation algorithm has limitations, "
+                "and increasing this value may not guarantee successful model compilation."
+            )

     @classmethod
     def get_maximum_num_blocks(
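
The new maybe_suggest_kvcache_num_blocks works from measured numbers rather than estimates: it sums each allocation key across the nodes of the prefill and decoder graphs, pulls out the kv-cache ("PortRecur") and weight ("Kernel") entries, and treats whatever remains as working buffer. A toy rerun of that aggregation with fabricated byte counts; "Scratch" is a made-up key for illustration, and only "PortRecur" and "Kernel" are named by the diff:

# Fabricated per-node allocations (bytes); real values come from
# RBLNCompiledModel.get_alloc_per_node_by_key().
prefill_alloc = {"Kernel": [3_000_000, 3_000_000], "PortRecur": [8_000_000], "Scratch": [500_000]}
decoder_alloc = {"Kernel": [3_000_000, 3_000_000], "PortRecur": [8_000_000], "Scratch": [250_000]}

alloc_memory_by_key = {key: sum(per_node) for key, per_node in prefill_alloc.items()}
for key, per_node in decoder_alloc.items():
    alloc_memory_by_key[key] += sum(per_node)

alloc_memory_by_key.pop("PortRecur")             # kv-cache, accounted separately
kernel_size = alloc_memory_by_key.pop("Kernel")  # model weights
buffer = sum(alloc_memory_by_key.values())       # everything else
print(kernel_size, buffer)  # 12000000 750000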
@@ -573,8 +646,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         config: PretrainedConfig,
         tensor_parallel_size: int,
         kvcache_block_size: int,
-        nbits_per_param: int,
-        n_model_params: int,
+        nbits_per_param: Optional[int] = None,
+        n_model_params: Optional[int] = None,
+        kernel_size: Optional[int] = None,
+        buffer: Optional[int] = None,
     ) -> int:
         """
         We are finding max_n_blocks(x) that satisfies the following equation:
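
After this signature change there are two mutually exclusive ways to call the method: estimate weight memory from a parameter count (the pre-compilation path), or pass the measured kernel_size and buffer taken from a compiled artifact (the path maybe_suggest_kvcache_num_blocks uses). A sketch of both call shapes, assuming model_config and the measured kernel_size/buffer from the previous snippet are in scope; all numeric values are placeholders:

# 1) Estimation path: kernel memory derived from the parameter count.
max_blocks = RBLNDecoderOnlyModelForCausalLM.get_maximum_num_blocks(
    config=model_config,
    tensor_parallel_size=4,
    kvcache_block_size=128,
    nbits_per_param=16,
    n_model_params=7_000_000_000,
)

# 2) Measured path: byte counts taken from the compiled model's allocations.
max_blocks = RBLNDecoderOnlyModelForCausalLM.get_maximum_num_blocks(
    config=model_config,
    tensor_parallel_size=4,
    kvcache_block_size=128,
    kernel_size=kernel_size,
    buffer=buffer,
)

Passing n_model_params together with kernel_size raises a ValueError, per the guard added in the next hunk.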
@@ -624,24 +699,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)

-        # Get estimated kernel size (approximated)
-        lm_heads_params = align(vocab_size, 64) * hidden_size
-        lm_heads_nbytes = (
-            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-        )
-        params = n_model_params - lm_heads_params
-        layer_nbytes = (
-            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-            * num_layers
-            * tensor_parallel_size
-        )
-        kernel_size = layer_nbytes + lm_heads_nbytes
+        if kernel_size is None:
+            if n_model_params is None:
+                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+            # Get estimated kernel size (approximated)
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+            kernel_size = layer_nbytes + lm_heads_nbytes
+        elif n_model_params is not None:
+            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")

         available_dram -= kernel_size

-        # TODO: Accurate buffer estimation
-        buffer_per_core = 2**29  # 500MB per npu
-        buffer = buffer_per_core * tensor_parallel_size
+        if buffer is None:
+            # TODO: Accurate buffer estimation
+            buffer_per_core = 2**29  # 500MB per npu
+            buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer

         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
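
To make the estimation arithmetic concrete, here is a self-contained toy run of the kernel-size formula from this hunk. The align and align_2MB helpers are reimplemented from their names (round up to a multiple; round up to a 2 MiB boundary), which is an assumption, and the model dimensions are made up:

def align(x: int, n: int) -> int:
    # Assumed behavior: round x up to a multiple of n.
    return -(-x // n) * n

def align_2MB(nbytes: float) -> int:
    # Assumed behavior: round a byte count up to a 2 MiB boundary.
    return align(int(nbytes), 2 * 2**20)

# Illustrative dimensions, not any particular checkpoint.
vocab_size, hidden_size, num_layers = 32_000, 4_096, 32
tensor_parallel_size, nbits_per_param = 4, 16
n_model_params = 7_000_000_000

lm_heads_params = align(vocab_size, 64) * hidden_size
lm_heads_nbytes = (
    align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
)
params = n_model_params - lm_heads_params
layer_nbytes = (
    align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
    * num_layers
    * tensor_parallel_size
)
kernel_size = layer_nbytes + lm_heads_nbytes
print(f"{kernel_size / 2**30:.2f} GiB of weights")  # 13.25 GiB for these numbers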
@@ -651,6 +732,74 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         return max_n_blocks
 
+    @classmethod
+    def get_input_info(
+        cls,
+        batch_size: int,
+        query_length: int,
+        use_inputs_embeds: bool,
+        use_attention_mask: bool,
+        max_seq_len: int,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
+        num_key_value_heads: int,
+        num_hidden_layers: int,
+        hidden_size: int,
+        head_dim: int,
+    ):
+        if use_inputs_embeds:
+            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+        else:
+            main_input = ("input_ids", [batch_size, query_length], "int64")
+
+        input_info = [
+            main_input,
+            (
+                "cache_position",
+                [batch_size, query_length],
+                "int32",
+            ),
+        ]
+
+        if use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                ]
+            )
+
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("query_position", [], "int16"),
+                ]
+            )
+
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        input_info.extend(
+            [
+                (
+                    f"past_key_values_{i}",
+                    [
+                        kvcache_num_blocks,
+                        num_key_value_heads,
+                        kvcache_block_size,
+                        head_dim,
+                    ],
+                    "float32",
+                )
+                for i in range(num_hidden_layers * 2)
+            ]
+        )
+
+        return input_info
+
     @classmethod
     def _update_rbln_config(
         cls,
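
A sketch of what the promoted classmethod returns for a small decode-shaped configuration (all numbers illustrative). Each entry is a (name, shape, dtype) triple that later seeds an RBLNCompileConfig:

dec_input_info = RBLNDecoderOnlyModelForCausalLM.get_input_info(
    batch_size=2,
    query_length=1,          # decode: one token per step
    use_inputs_embeds=False,
    use_attention_mask=False,
    max_seq_len=1024,
    kvcache_block_size=128,
    kvcache_num_blocks=16,
    num_key_value_heads=2,
    num_hidden_layers=1,
    hidden_size=64,
    head_dim=32,
)
# Expected entries, following the branches above:
#   ("input_ids",         [2, 1],           "int64")
#   ("cache_position",    [2, 1],           "int32")
#   ("block_tables",      [2, 8],           "int16")   # batch x (max_seq_len // block_size)
#   ("past_key_values_0", [16, 2, 128, 32], "float32")
#   ("past_key_values_1", [16, 2, 128, 32], "float32")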
@@ -680,120 +829,70 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=rbln_config.max_seq_len,
         )

-        rbln_config.kvcache_num_blocks = (
-            rbln_config.max_seq_len // rbln_config.kvcache_block_size
-        ) * rbln_config.batch_size
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        max_num_blocks = required_num_blocks

         if rbln_config.attn_impl == "flash_attn":
-            max_num_blocks = cls.get_maximum_num_blocks(
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
                 tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
                 kvcache_block_size=rbln_config.kvcache_block_size,
                 nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
                 n_model_params=sum(p.numel() for p in model.parameters()),
             )
-            rbln_config.kvcache_num_blocks = min(rbln_config.kvcache_num_blocks, max_num_blocks)

-            required_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if rbln_config.kvcache_num_blocks < required_blocks:
-                rbln_config.kvcache_num_blocks = required_blocks
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)

-            logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks

-        if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+        if max_num_blocks < rbln_config.batch_size:
             raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({rbln_config.kvcache_num_blocks}). "
+                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
                 "Ensure the number of blocks is at least equal to the batch size."
             )

+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+        elif rbln_config.kvcache_num_blocks > max_num_blocks:
+            logger.warning(
+                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                f" than the estimated maximum number of blocks ({max_num_blocks})."
+                "This can cause a failure during model compilation."
+            )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads

-        def get_input_info(
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            hidden_size,
-            use_attention_mask,
-            max_seq_len,
-            kvcache_block_size,
-            kvcache_num_blocks,
-        ):
-            if use_inputs_embeds:
-                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-            else:
-                main_input = ("input_ids", [batch_size, query_length], "int64")
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
-
-            if use_attention_mask:
-                input_info.extend(
-                    [
-                        ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
-                    ]
-                )
-
-            if query_length > 1:
-                input_info.extend(
-                    [
-                        ("query_position", [], "int16"),
-                    ]
-                )
-
-            max_block_cnt = max_seq_len // kvcache_block_size
-
-            if query_length > 1:
-                input_info.extend([("block_tables", [max_block_cnt], "int16")])
-            else:
-                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            kvcache_num_blocks,
-                            num_key_value_heads,
-                            kvcache_block_size,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
+        prefill_input_info = cls.get_input_info(
             batch_size=1,
             query_length=rbln_config.prefill_chunk_size,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
-            hidden_size=hidden_size,
             use_attention_mask=rbln_config.use_attention_mask,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            hidden_size=hidden_size,
+            head_dim=head_dim,
         )
-        dec_input_info = get_input_info(
+        dec_input_info = cls.get_input_info(
             batch_size=rbln_config.batch_size,
             query_length=1,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
-            hidden_size=hidden_size,
             use_attention_mask=rbln_config.use_attention_mask,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            hidden_size=hidden_size,
+            head_dim=head_dim,
        )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
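
Condensed, the new selection logic reads: start from the blocks needed to hold every sequence at full length, cap that by the flash-attention memory estimate, keep at least one block more than a single full sequence needs, and only then defer to an explicit user setting. A standalone restatement, with the estimator stubbed by a plain number and the warning omitted:

def resolve_kvcache_num_blocks(max_seq_len, block_size, batch_size,
                               user_num_blocks=None, estimated_max=None):
    # required_num_blocks: every sequence held at full length.
    max_num_blocks = (max_seq_len // block_size) * batch_size
    if estimated_max is not None:  # flash_attn path
        max_num_blocks = min(max_num_blocks, estimated_max)
        # flash attention needs one spare block beyond a single full sequence
        max_num_blocks = max(max_num_blocks, max_seq_len // block_size + 1)
    if max_num_blocks < batch_size:
        raise RuntimeError("Batch size exceeds available KV cache blocks.")
    if user_num_blocks is None:
        return max_num_blocks
    return user_num_blocks  # kept as-is; the real code warns if it exceeds max_num_blocks

print(resolve_kvcache_num_blocks(4096, 128, 2))                    # 64
print(resolve_kvcache_num_blocks(4096, 128, 2, estimated_max=40))  # 40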
@@ -864,7 +963,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
-                    "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                 )
         else:
             model_inputs.update({"input_ids": input_ids})
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_idefics3 import RBLNIdefics3ForConditionalGenerationConfig, RBLNIdefics3VisionTransformerConfig
+from .modeling_idefics3 import RBLNIdefics3ForConditionalGeneration, RBLNIdefics3VisionTransformer
@@ -0,0 +1,51 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
+    pass
+
+
+class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    submodules = ["vision_model", "text_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_model: Optional[RBLNModelConfig] = None,
+        text_model: Optional[RBLNModelConfig] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_model (Optional[RBLNModelConfig]): Configuration for the vision transformer component.
+            text_model (Optional[RBLNModelConfig]): Configuration for the text model component.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vision_model = vision_model
+        self.text_model = text_model
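
A hypothetical construction of the new config class; the nested values are placeholders, and per the code above, omitted submodule configs simply stay None:

config = RBLNIdefics3ForConditionalGenerationConfig(
    batch_size=1,
    vision_model=RBLNIdefics3VisionTransformerConfig(),
    text_model=None,  # may be supplied later by the loading pipeline
)
assert config.batch_size == 1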