optimum-rbln 0.7.4a5__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as published to a supported public registry, and is provided for informational purposes only. The changes fall into three groups: threading an optional `position_embed` input through the decoder-only runtime, reworking KV-cache block-count estimation around measured allocations, and adding configuration classes for Qwen2.5-VL.
@@ -163,6 +163,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -187,9 +188,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 block_tables,
                 is_external_block_tables,
                 attention_mask=attention_mask,
+                position_embed=position_embed,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
+            return self.prefill_forward(
+                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+            )
 
     def decode_forward(
         self,
@@ -198,6 +202,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -229,6 +234,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position,
             attention_mask if self.use_attention_mask else None,
             block_tables,
+            position_embed,
         )
 
         return logits
@@ -241,6 +247,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -251,6 +258,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )
 
         query_length = inputs.shape[1]
         if query_length > self.max_seq_len:
@@ -295,9 +306,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     dim=-1,
                 )
 
+                if position_embed is not None:
+                    position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
             # Extract the current chunk of inputs and cache positions
             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            if position_embed is not None:
+                position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
 
             if self.use_attention_mask:
                 # Update attention mask to ensure proper causal behavior
@@ -315,6 +331,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.use_attention_mask else None,
                 query_position,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )
 
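Taken together, the hunks above thread an optional `position_embed` tensor through `forward`, `decode_forward`, and `prefill_forward`, so models that precompute position embeddings (such as the Qwen2.5-VL support added below) can feed them to the compiled runtime. Below is a minimal sketch of the mask/pad/chunk handling, assuming a rotary-style 5-D layout `[cos|sin, batch, 1, seq_len, head_dim]`; the layout is inferred from the indexing in the diff, not documented by the package.

```python
import torch

# Hypothetical shapes for illustration: the diff slices dim 3 as the
# sequence axis, suggesting [cos|sin, batch, 1, seq_len, head_dim].
position_embed = torch.randn(2, 1, 1, 300, 64)
attention_mask = torch.cat([torch.ones(280), torch.zeros(20)])
prefill_chunk_size = 128

# 1. Keep only valid (non-masked) positions, as in prefill_forward.
position_embed = position_embed[:, :, :, attention_mask.bool(), :]
query_length = position_embed.shape[-2]  # 280

# 2. Pad the sequence axis up to a multiple of the chunk size.
padding_size = -query_length % prefill_chunk_size  # 104
position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))

# 3. Slice one chunk per prefill step.
for step in range(0, query_length, prefill_chunk_size):
    position_embed_chunk = position_embed[:, :, :, step : step + prefill_chunk_size, :]
    print(step, tuple(position_embed_chunk.shape))  # (2, 1, 1, 128, 64)
```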
@@ -565,7 +582,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         finally:
             torch.nn.functional.linear = original_linear
 
-        return compile_model(quantize_config=rbln_config.quantization)
+        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
+            )
+
+        return compiled_models
+
+    @classmethod
+    def maybe_suggest_kvcache_num_blocks(
+        cls,
+        compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        model_config: PretrainedConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> None:
+        # Get the actual memory allocation of each node by key
+        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+        alloc_memory_by_key: Dict[str, int] = {
+            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+        }
+        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+        # Get the maximum number of blocks that can be allocated
+        buffer = sum(alloc_memory_by_key.values())
+        max_num_blocks = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_config.tensor_parallel_size,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kernel_size=kernel_size,
+            buffer=buffer,
+        )
+
+        # Since our estimation logic is not always accurate,
+        # users can set `kvcache_num_blocks` to `max_num_blocks`.
+        # If the memory is not enough, the model will fail to compile.
+        if rbln_config.kvcache_num_blocks < max_num_blocks:
+            logger.warning(
+                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                "Our analysis indicates that additional memory is available for more blocks. "
+                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                "Please be advised that our memory estimation algorithm has limitations, "
+                "and increasing this value may not guarantee successful model compilation."
+            )
 
     @classmethod
     def get_maximum_num_blocks(
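The new `maybe_suggest_kvcache_num_blocks` hook only fires when the configured block count cannot hold full-length sequences for the entire batch. The trigger arithmetic, with illustrative numbers (not taken from any real model):

```python
# With max_seq_len=8192, kvcache_block_size=512, batch_size=2, holding
# full-length sequences for every batch entry takes (8192 // 512) * 2 blocks.
max_seq_len, kvcache_block_size, batch_size = 8192, 512, 2
required_num_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 32

# The suggestion path runs only if the configured count fell short, e.g. 24.
kvcache_num_blocks = 24
if kvcache_num_blocks < required_num_blocks:
    print(f"memory permitting, consider raising kvcache_num_blocks toward {required_num_blocks}")
```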
@@ -573,8 +640,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         config: PretrainedConfig,
         tensor_parallel_size: int,
         kvcache_block_size: int,
-        nbits_per_param: int,
-        n_model_params: int,
+        nbits_per_param: Optional[int] = None,
+        n_model_params: Optional[int] = None,
+        kernel_size: Optional[int] = None,
+        buffer: Optional[int] = None,
     ) -> int:
         """
         We are finding max_n_blocks(x) that satisfies the following equation:
@@ -624,24 +693,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
 
-        # Get estimated kernel size (approximated)
-        lm_heads_params = align(vocab_size, 64) * hidden_size
-        lm_heads_nbytes = (
-            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-        )
-        params = n_model_params - lm_heads_params
-        layer_nbytes = (
-            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-            * num_layers
-            * tensor_parallel_size
-        )
-        kernel_size = layer_nbytes + lm_heads_nbytes
+        if kernel_size is None:
+            if n_model_params is None:
+                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+            # Get estimated kernel size (approximated)
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+            kernel_size = layer_nbytes + lm_heads_nbytes
+        elif n_model_params is not None:
+            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
 
         available_dram -= kernel_size
 
-        # TODO: Accurate buffer estimation
-        buffer_per_core = 2**29  # 500MB per npu
-        buffer = buffer_per_core * tensor_parallel_size
+        if buffer is None:
+            # TODO: Accurate buffer estimation
+            buffer_per_core = 2**29  # 500MB per npu
+            buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer
 
         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
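With `kernel_size` and `buffer` now injectable, one budget formula serves both call sites: the pre-compile estimate (weights approximated from parameter counts) and the post-compile suggestion (weights and buffers measured from the compiled artifacts). A rough sketch of the budget flow under assumed values follows; the per-device DRAM constant and the final per-block division are outside this hunk, so those parts are guesses rather than the library's exact formula.

```python
import math

def align(x: int, n: int) -> int:
    # Round x up to the next multiple of n.
    return math.ceil(x / n) * n

ATOM_DRAM_NBYTES = 16 * 2**30       # assumed per-device DRAM; not shown in the hunk
ATOM_SYS_DRAM_NBYTES = 288 * 2**20  # system reservation, as in the hunk

tensor_parallel_size = 4
kernel_size = 14 * 2**30  # e.g. measured "Kernel" allocation (model weights)
buffer = 2 * 2**30        # e.g. sum of the remaining allocation keys

available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
available_dram -= kernel_size
available_dram -= buffer

# Per-layer, per-block KV bytes, exactly as the hunk computes `b`.
kvcache_block_size, head_dim, num_key_value_heads, num_layers = 512, 128, 8, 32
b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2

# The closing step is outside this hunk; dividing the remaining DRAM by the
# total per-block footprint across layers is the general idea.
max_n_blocks = available_dram // (b * num_layers)
print(max_n_blocks)
```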
@@ -651,6 +726,74 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return max_n_blocks
 
 
+    @classmethod
+    def get_input_info(
+        cls,
+        batch_size: int,
+        query_length: int,
+        use_inputs_embeds: bool,
+        use_attention_mask: bool,
+        max_seq_len: int,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
+        num_key_value_heads: int,
+        num_hidden_layers: int,
+        hidden_size: int,
+        head_dim: int,
+    ):
+        if use_inputs_embeds:
+            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+        else:
+            main_input = ("input_ids", [batch_size, query_length], "int64")
+
+        input_info = [
+            main_input,
+            (
+                "cache_position",
+                [batch_size, query_length],
+                "int32",
+            ),
+        ]
+
+        if use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                ]
+            )
+
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("query_position", [], "int16"),
+                ]
+            )
+
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        input_info.extend(
+            [
+                (
+                    f"past_key_values_{i}",
+                    [
+                        kvcache_num_blocks,
+                        num_key_value_heads,
+                        kvcache_block_size,
+                        head_dim,
+                    ],
+                    "float32",
+                )
+                for i in range(num_hidden_layers * 2)
+            ]
+        )
+
+        return input_info
+
     @classmethod
     def _update_rbln_config(
         cls,
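Promoting `get_input_info` from a closure inside `_update_rbln_config` (deleted in the next hunk) to a classmethod lets subclasses reuse and override it. A sketch of calling it directly with toy values; the import path is abbreviated and may differ between optimum-rbln versions.

```python
from optimum.rbln import RBLNDecoderOnlyModelForCausalLM

common = dict(
    use_inputs_embeds=False,
    use_attention_mask=True,
    max_seq_len=4096,
    kvcache_block_size=512,
    kvcache_num_blocks=16,
    num_key_value_heads=8,
    num_hidden_layers=2,
    hidden_size=2048,
    head_dim=128,
)

# Prefill graph: batch 1, chunk-sized query -> scalar query_position plus a
# shared 1-D block table of length max_seq_len // kvcache_block_size == 8.
prefill = RBLNDecoderOnlyModelForCausalLM.get_input_info(batch_size=1, query_length=128, **common)

# Decode graph: query_length 1 -> per-batch block table [4, 8], no query_position.
decode = RBLNDecoderOnlyModelForCausalLM.get_input_info(batch_size=4, query_length=1, **common)

for name, shape, dtype in prefill[:4]:
    print(name, shape, dtype)
# input_ids [1, 128] int64
# cache_position [1, 128] int32
# attention_mask [1, 1, 128, 4096] float32
# query_position [] int16
```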
@@ -680,120 +823,70 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=rbln_config.max_seq_len,
         )
 
-        rbln_config.kvcache_num_blocks = (
-            rbln_config.max_seq_len // rbln_config.kvcache_block_size
-        ) * rbln_config.batch_size
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        max_num_blocks = required_num_blocks
 
         if rbln_config.attn_impl == "flash_attn":
-            max_num_blocks = cls.get_maximum_num_blocks(
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
                 tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
                 kvcache_block_size=rbln_config.kvcache_block_size,
                 nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
                 n_model_params=sum(p.numel() for p in model.parameters()),
             )
-            rbln_config.kvcache_num_blocks = min(rbln_config.kvcache_num_blocks, max_num_blocks)
 
-            required_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if rbln_config.kvcache_num_blocks < required_blocks:
-                rbln_config.kvcache_num_blocks = required_blocks
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
 
-            logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks
 
-        if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+            if max_num_blocks < rbln_config.batch_size:
                 raise RuntimeError(
-                    f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({rbln_config.kvcache_num_blocks}). "
+                    f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
                     "Ensure the number of blocks is at least equal to the batch size."
                 )
 
+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+        elif rbln_config.kvcache_num_blocks > max_num_blocks:
+            logger.warning(
+                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                f" than the estimated maximum number of blocks ({max_num_blocks})."
+                "This can cause a failure during model compilation."
+            )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
 
-        def get_input_info(
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            hidden_size,
-            use_attention_mask,
-            max_seq_len,
-            kvcache_block_size,
-            kvcache_num_blocks,
-        ):
-            if use_inputs_embeds:
-                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-            else:
-                main_input = ("input_ids", [batch_size, query_length], "int64")
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
-
-            if use_attention_mask:
-                input_info.extend(
-                    [
-                        ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
-                    ]
-                )
-
-            if query_length > 1:
-                input_info.extend(
-                    [
-                        ("query_position", [], "int16"),
-                    ]
-                )
-
-            max_block_cnt = max_seq_len // kvcache_block_size
-
-            if query_length > 1:
-                input_info.extend([("block_tables", [max_block_cnt], "int16")])
-            else:
-                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            kvcache_num_blocks,
-                            num_key_value_heads,
-                            kvcache_block_size,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
+        prefill_input_info = cls.get_input_info(
             batch_size=1,
             query_length=rbln_config.prefill_chunk_size,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
-            hidden_size=hidden_size,
             use_attention_mask=rbln_config.use_attention_mask,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            hidden_size=hidden_size,
+            head_dim=head_dim,
         )
-        dec_input_info = get_input_info(
+        dec_input_info = cls.get_input_info(
             batch_size=rbln_config.batch_size,
             query_length=1,
             use_inputs_embeds=rbln_config.use_inputs_embeds,
-            hidden_size=hidden_size,
             use_attention_mask=rbln_config.use_attention_mask,
             max_seq_len=rbln_config.max_seq_len,
             kvcache_block_size=rbln_config.kvcache_block_size,
             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            hidden_size=hidden_size,
+            head_dim=head_dim,
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
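Note the semantic change: `kvcache_num_blocks` may now be `None`, meaning "use the estimated maximum", while an explicit value larger than the estimate is kept but triggers a warning instead of being silently clamped as before. A condensed restatement of that policy (illustrative helper, not package code):

```python
import logging

logger = logging.getLogger(__name__)

def resolve_kvcache_num_blocks(user_value, max_num_blocks):
    # None defers to the library's estimate; explicit values are honored.
    if user_value is None:
        return max_num_blocks
    if user_value > max_num_blocks:
        logger.warning(
            "The set `kvcache_num_blocks` (%d) is greater than the estimated "
            "maximum number of blocks (%d). This can cause a failure during "
            "model compilation.", user_value, max_num_blocks,
        )
    return user_value
```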
@@ -864,7 +957,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
-                    "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                 )
         else:
             model_inputs.update({"input_ids": input_ids})
New file (from its relative imports, this appears to be the package `__init__.py` for the new `qwen2_5_vl` model directory):

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_qwen2_5_vl import (
+    RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+    RBLNQwen2_5_VLForConditionalGenerationConfig,
+)
+from .modeling_qwen2_5_vl import RBLNQwen2_5_VisionTransformerPretrainedModel, RBLNQwen2_5_VLForConditionalGeneration
New file (evidently `configuration_qwen2_5_vl.py`, which the `__init__.py` above imports from):

@@ -0,0 +1,68 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+from ....configuration_utils import RBLNModelConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    submodules = ["visual"]
+
+    def __init__(
+        self,
+        visual: Optional[RBLNModelConfig] = None,
+        use_inputs_embeds: bool = True,
+        **kwargs,
+    ):
+        super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
+        if not self.use_inputs_embeds:
+            raise ValueError(
+                "RBLNQwen2_5_VLForConditionalGenerationConfig does not allow `use_inputs_embeds` to be set to False, "
+                "as RBLNQwen2_5_VLForConditionalGeneration accepts only `inputs_embeds` as input."
+            )
+        self.visual = visual
+
+
+class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
+    def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs):
+        """
+        Args:
+            max_seq_lens (Optional[Union[int, List[int]]]): Maximum sequence lengths for Vision
+                Transformer attention. Can be an integer or list of integers, each indicating
+                the number of patches in a sequence for an image or video. For example, an image
+                of 224x196 pixels with patch size 14 and window size 112 has its width padded to
+                224, forming a 224x224 image. This yields 256 patches [(224/14) * (224/14)], so
+                `max_seq_len` must be at least 256. For window-based attention, `max_seq_len`
+                must be a multiple of `(window_size / patch_size)^2`, e.g., (112/14)^2 = 64,
+                making 256 (64 * 4) valid. RBLN optimization runs inference per image or video
+                frame, so set `max_seq_len` to match the maximum expected resolution to reduce
+                computation. If not provided, a `ValueError` is raised.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+
+        if max_seq_lens is not None:
+            if isinstance(max_seq_lens, int):
+                max_seq_lens = [max_seq_lens]
+            elif isinstance(max_seq_lens, list):
+                max_seq_lens.sort(reverse=True)
+        else:
+            raise ValueError("'max_seq_lens' must be specified.")
+
+        self.max_seq_lens = max_seq_lens
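The `max_seq_lens` docstring's patch arithmetic, checked as a worked example (the numbers come from the docstring itself; the helper name is hypothetical):

```python
def min_valid_max_seq_len(width: int, height: int, patch_size: int = 14, window_size: int = 112) -> int:
    # Pad each spatial side up to a multiple of the window, then count patches.
    def pad(x: int) -> int:
        return -(-x // window_size) * window_size  # ceil to a multiple
    return (pad(width) // patch_size) * (pad(height) // patch_size)

# A 224x196 image pads to 224x224 -> (224/14) * (224/14) = 256 patches,
# and 256 is a multiple of (112/14) ** 2 == 64, so max_seq_lens=256 is valid.
assert min_valid_max_seq_len(224, 196) == 256
```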