optimum-rbln 0.8.2a5__py3-none-any.whl → 0.8.2a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release has been flagged as potentially problematic; details are available on the package registry page.

optimum/rbln/__init__.py CHANGED
@@ -102,6 +102,8 @@ _import_structure = {
         "RBLNLlamaModelConfig",
         "RBLNOPTForCausalLM",
         "RBLNOPTForCausalLMConfig",
+        "RBLNLlavaForConditionalGeneration",
+        "RBLNLlavaForConditionalGenerationConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNMidmLMHeadModel",
@@ -118,6 +120,8 @@ _import_structure = {
         "RBLNPegasusModelConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
+        "RBLNPixtralVisionModel",
+        "RBLNPixtralVisionModelConfig",
         "RBLNPhiModel",
         "RBLNPhiModelConfig",
         "RBLNQwen2ForCausalLM",
@@ -369,6 +373,8 @@ if TYPE_CHECKING:
         RBLNLlamaForCausalLMConfig,
         RBLNLlamaModel,
         RBLNLlamaModelConfig,
+        RBLNLlavaForConditionalGeneration,
+        RBLNLlavaForConditionalGenerationConfig,
         RBLNLlavaNextForConditionalGeneration,
         RBLNLlavaNextForConditionalGenerationConfig,
         RBLNMidmLMHeadModel,
@@ -389,6 +395,8 @@ if TYPE_CHECKING:
         RBLNPhiForCausalLMConfig,
         RBLNPhiModel,
         RBLNPhiModelConfig,
+        RBLNPixtralVisionModel,
+        RBLNPixtralVisionModelConfig,
         RBLNQwen2_5_VisionTransformerPretrainedModel,
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
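
Both new model families are registered through the package's lazy `_import_structure` mapping, so they resolve on first access from the package root. A minimal usage sketch; the checkpoint ID is a hypothetical placeholder, not taken from this diff:

```python
# Sketch only: "llava-hf/llava-1.5-7b-hf" is an illustrative checkpoint ID.
from optimum.rbln import RBLNLlavaForConditionalGeneration

# export=True compiles the Hugging Face checkpoint for RBLN NPUs,
# following the usual optimum-rbln from_pretrained pattern.
model = RBLNLlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    export=True,
)
```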

optimum/rbln/__version__.py CHANGED

@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.8.2a5'
-__version_tuple__ = version_tuple = (0, 8, 2, 'a5')
+__version__ = version = '0.8.2a6'
+__version_tuple__ = version_tuple = (0, 8, 2, 'a6')

optimum/rbln/transformers/__init__.py CHANGED

@@ -86,6 +86,8 @@ _import_structure = {
         "RBLNIdefics3VisionTransformerConfig",
         "RBLNLlamaForCausalLM",
         "RBLNLlamaForCausalLMConfig",
+        "RBLNLlavaForConditionalGeneration",
+        "RBLNLlavaForConditionalGenerationConfig",
         "RBLNLlamaModel",
         "RBLNLlamaModelConfig",
         "RBLNOPTForCausalLM",
@@ -108,6 +110,8 @@ _import_structure = {
         "RBLNOPTModelConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
+        "RBLNPixtralVisionModelConfig",
+        "RBLNPixtralVisionModel",
         "RBLNPhiModel",
         "RBLNPhiModelConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
@@ -220,6 +224,8 @@ if TYPE_CHECKING:
         RBLNLlamaForCausalLMConfig,
         RBLNLlamaModel,
         RBLNLlamaModelConfig,
+        RBLNLlavaForConditionalGeneration,
+        RBLNLlavaForConditionalGenerationConfig,
         RBLNLlavaNextForConditionalGeneration,
         RBLNLlavaNextForConditionalGenerationConfig,
         RBLNMidmLMHeadModel,
@@ -240,6 +246,8 @@ if TYPE_CHECKING:
         RBLNPhiForCausalLMConfig,
         RBLNPhiModel,
         RBLNPhiModelConfig,
+        RBLNPixtralVisionModel,
+        RBLNPixtralVisionModelConfig,
         RBLNQwen2_5_VisionTransformerPretrainedModel,
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,

optimum/rbln/transformers/models/__init__.py CHANGED

@@ -106,6 +106,7 @@ _import_structure = {
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformerConfig",
     ],
+    "llava": ["RBLNLlavaForConditionalGeneration", "RBLNLlavaForConditionalGenerationConfig"],
     "llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig", "RBLNLlamaModel", "RBLNLlamaModelConfig"],
     "opt": ["RBLNOPTForCausalLM", "RBLNOPTForCausalLMConfig", "RBLNOPTModel", "RBLNOPTModelConfig"],
     "pegasus": [
@@ -116,6 +117,7 @@ _import_structure = {
     ],
     "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
     "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
+    "pixtral": ["RBLNPixtralVisionModel", "RBLNPixtralVisionModelConfig"],
     "mistral": [
         "RBLNMistralForCausalLM",
         "RBLNMistralForCausalLMConfig",
@@ -241,6 +243,7 @@ if TYPE_CHECKING:
         RBLNIdefics3VisionTransformerConfig,
     )
     from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig, RBLNLlamaModel, RBLNLlamaModelConfig
+    from .llava import RBLNLlavaForConditionalGeneration, RBLNLlavaForConditionalGenerationConfig
     from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
     from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
     from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
@@ -252,6 +255,7 @@ if TYPE_CHECKING:
         RBLNPegasusModelConfig,
     )
     from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig, RBLNPhiModel, RBLNPhiModelConfig
+    from .pixtral import RBLNPixtralVisionModel, RBLNPixtralVisionModelConfig
     from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig, RBLNQwen2Model, RBLNQwen2ModelConfig
     from .qwen2_5_vl import (
         RBLNQwen2_5_VisionTransformerPretrainedModel,

optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -352,6 +352,8 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
+        if phases is not None:
+            self.validate_phases_type(phases)
         self.phases = phases or ["prefill", "decode"]
 
         if "decode" in self.phases:
@@ -374,6 +376,13 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
         # Larger batch size should be at the beginning of the list.
         self.decoder_batch_sizes.sort(reverse=True)
 
+    @staticmethod
+    def validate_phases_type(phases: List[PhaseType]):
+        if not isinstance(phases, list):
+            raise ValueError("`phases` must be a list.")
+        if not all(phase in get_args(PhaseType) for phase in phases):
+            raise ValueError(f"All elements in `phases` must be of type `PhaseType`({get_args(PhaseType)}).")
+
     @property
     def use_multiple_decoder(self):
         return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
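
The new validator leans on `typing.get_args` to enumerate the allowed `Literal` members. A self-contained sketch, assuming `PhaseType` is a `Literal` over the phase names that appear elsewhere in this diff:

```python
from typing import List, Literal, get_args

# Assumption: PhaseType covers the phase names seen elsewhere in this diff.
PhaseType = Literal["prefill", "image_prefill", "decode"]

def validate_phases_type(phases: List[PhaseType]) -> None:
    if not isinstance(phases, list):
        raise ValueError("`phases` must be a list.")
    # get_args(PhaseType) yields ("prefill", "image_prefill", "decode")
    if not all(phase in get_args(PhaseType) for phase in phases):
        raise ValueError(f"All elements in `phases` must be of type `PhaseType`({get_args(PhaseType)}).")

validate_phases_type(["prefill", "decode"])  # passes silently
# validate_phases_type(["train"])            # would raise ValueError
```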

optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py CHANGED

@@ -1024,8 +1024,9 @@ class SlidingWindowAttentionOp(AttentionOp):
             "block_size": block_size,
         }
 
-        if "prefill" in self.phase:
-            op_args["is_bidirectional"] = True
+        if self.phase == "prefill" or self.phase == "image_prefill":
+            if not self.use_attention_mask or self.use_position_ids:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
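
The rewritten branch makes bidirectional attention an image-prefill-only behavior: text prefill now stays causal, and the flag is emitted only when the mask settings allow it. A standalone sketch of the decision, with names copied from the diff:

```python
# Sketch of the new gating, lifted out of SlidingWindowAttentionOp for readability.
def is_bidirectional_flag(phase: str, use_attention_mask: bool, use_position_ids: bool):
    """Return the is_bidirectional op arg, or None when it is omitted."""
    if phase in ("prefill", "image_prefill"):
        if not use_attention_mask or use_position_ids:
            # Only image prefill runs bidirectionally; the source marks this
            # as hard-coded for Gemma3 (see the FIXME above).
            return phase == "image_prefill"
    return None

assert is_bidirectional_flag("image_prefill", True, True) is True
assert is_bidirectional_flag("prefill", True, True) is False
assert is_bidirectional_flag("decode", True, True) is None
```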
@@ -1213,7 +1213,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
             kvcache_block_size=rbln_config.kvcache_block_size,
             nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
             n_model_params=sum(p.numel() for p in model.parameters()),
-            num_runtimes=1 if rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
+            num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
         )
 
         max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
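
The one-character fix inverts a condition that previously undercounted runtimes for generative models. The intended arithmetic, as a quick check:

```python
# Before: 1 if can_generate else 1 + len(decoder_batch_sizes)   (inverted)
# After:  1 if not can_generate else 1 + len(decoder_batch_sizes)
def num_runtimes(can_generate: bool, decoder_batch_sizes: list) -> int:
    return 1 if not can_generate else 1 + len(decoder_batch_sizes)

assert num_runtimes(False, [4, 1]) == 1   # forward-only: prefill runtime only
assert num_runtimes(True, [4, 1]) == 3    # prefill + one decoder per batch size
```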
@@ -1395,8 +1395,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
         # The decoder stage operates as usual, processing inputs in batch mode.
 
         # for only use forward
-        if not self.can_generate():
-            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+        if generate_idx is None:
+            generate_idx = (
+                attention_mask.sum(dim=-1, keepdim=True).int()
+                if attention_mask is not None
+                else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+            )
         padded_cache_lengths = torch.zeros_like(generate_idx)
 
         # Prefll
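
With this change, prefill no longer requires an attention mask: when `generate_idx` is not supplied and no mask is available, the full `input_ids` length is used for every sequence in the batch. A minimal sketch of the fallback:

```python
import torch

input_ids = torch.ones(2, 7, dtype=torch.int64)  # batch of 2, 7 tokens each
attention_mask = None                            # e.g. mask-free forward-only usage

generate_idx = (
    attention_mask.sum(dim=-1, keepdim=True).int()
    if attention_mask is not None
    else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
)
print(generate_idx)  # tensor([[7], [7]], dtype=torch.int32)
```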

optimum/rbln/transformers/models/gemma3/configuration_gemma3.py CHANGED

@@ -37,6 +37,14 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         )
         self.image_prefill_chunk_size = image_prefill_chunk_size
 
+    @property
+    def use_image_prefill(self):
+        return self.image_prefill_chunk_size is not None
+
+    @property
+    def decoder_runtime_idx(self):
+        return 2 if self.use_image_prefill else 1
+
 
 class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
     submodules = ["vision_tower", "language_model"]
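
The two properties encode the runtime table layout used throughout the rest of this diff: index 0 is always the text prefill runtime, index 1 is the image prefill runtime only when `image_prefill_chunk_size` is set, and decoder runtimes start at `decoder_runtime_idx`. A sketch of the resulting layout:

```python
# Sketch of the runtime ordering implied by use_image_prefill / decoder_runtime_idx.
def runtime_layout(image_prefill_chunk_size, decoder_batch_sizes):
    names = ["prefill"]
    if image_prefill_chunk_size is not None:  # use_image_prefill
        names.append("image_prefill")
    names += [f"decoder_batch_{bs}" for bs in decoder_batch_sizes]
    return names

assert runtime_layout(256, [4, 1]) == ["prefill", "image_prefill", "decoder_batch_4", "decoder_batch_1"]
assert runtime_layout(None, [4, 1]) == ["prefill", "decoder_batch_4", "decoder_batch_1"]
```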

optimum/rbln/transformers/models/gemma3/modeling_gemma3.py CHANGED

@@ -337,11 +337,12 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
 
         # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-        padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-        inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-        cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-        position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-        token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
+        if self.rbln_config.use_image_prefill:
+            padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
+            inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+            cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
+            position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
+            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
 
         return (
             inputs,
@@ -389,7 +390,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         step = 0
         while step < query_length:
             # Check if the prefill chunk is an image prefill
-            is_image_prefill = torch.all(
+            is_image_prefill = self.rbln_config.use_image_prefill and torch.all(
                 token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
             )
             prefill_chunk_size = (
@@ -397,8 +398,10 @@
             )
 
             # Check if the prefill chunk is a text prefill which have image_tokens in it.
-            is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                token_type_ids[:, step : step + prefill_chunk_size] == 1
+            is_text_prefill_with_image_tokens = (
+                self.rbln_config.use_image_prefill
+                and not is_image_prefill
+                and torch.any(token_type_ids[:, step : step + prefill_chunk_size] == 1)
             )
 
             # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
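
Each prefill chunk is classified from `token_type_ids`, where 1 marks image tokens, and both image-related checks now short-circuit to False when image prefill is disabled. A toy illustration of the classification:

```python
import torch

# Toy token_type_ids (shape [1, seq_len]); 1 marks image tokens, 0 text.
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
use_image_prefill = True
image_prefill_chunk_size = prefill_chunk_size = 4
step = 2

is_image_prefill = use_image_prefill and torch.all(
    token_type_ids[:, step : step + image_prefill_chunk_size] == 1
)
is_text_prefill_with_image_tokens = (
    use_image_prefill
    and not is_image_prefill
    and torch.any(token_type_ids[:, step : step + prefill_chunk_size] == 1)
)
print(bool(is_image_prefill))  # True: this chunk is image tokens only
```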
@@ -418,7 +421,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             num_processed_tokens = prefill_chunk_size
             if is_text_prefill_with_image_tokens:
                 first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                num_processed_tokens = first_image_token_idx
+                num_processed_tokens = first_image_token_idx.item()
             if is_last_chunk:
                 num_processed_tokens = query_length - step
 
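
`torch.where` returns tensors, so the previous assignment left `num_processed_tokens` as a 0-dim tensor; `.item()` converts it to a plain Python int so the surrounding `step` bookkeeping stays in integer arithmetic:

```python
import torch

token_type_ids = torch.tensor([[0, 0, 1, 1]])
first_image_token_idx = torch.where(token_type_ids == 1)[1][0]
print(first_image_token_idx)         # tensor(2): a 0-dim tensor
print(first_image_token_idx.item())  # 2: a plain Python int
```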
@@ -548,9 +551,10 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-            image_prefill=self.model[1],
+            image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
@@ -565,7 +569,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                runtime=self.model[i + 2],
+                runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 main_input_name=main_input_name,
                 embed_tokens=self.embed_tokens,
                 phase="decode",
@@ -628,20 +632,21 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
 
-        # Update image prefill compile config
-        img_prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.image_prefill_chunk_size,
-            rbln_config=rbln_config,
-            model_config=model_config,
-        )
-        image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=img_prefill_input_info
-        )
-        # Insert image_prefill compile config at index 1
-        compile_cfgs = rbln_config.compile_cfgs
-        compile_cfgs.insert(1, image_prefill_compile_config)
-        rbln_config.set_compile_cfgs(compile_cfgs)
+        if rbln_config.use_image_prefill:
+            # Update image prefill compile config
+            img_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=img_prefill_input_info
+            )
+            # Insert image_prefill compile config at index 1
+            compile_cfgs = rbln_config.compile_cfgs
+            compile_cfgs.insert(1, image_prefill_compile_config)
+            rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
@@ -694,23 +699,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             context,
             rbln_config.quantization,
         )
+        compiled_models = {"prefill": compiled_prefill}
 
-        image_prefill_compile_config = rbln_compile_configs[1]
-        image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
-            fill=0, static_tensors=static_tensors
-        )
-        wrapped_model.phase = "image_prefill"
-        compiled_image_prefill = compile_model(
-            wrapped_model,
-            image_prefill_compile_config,
-            image_prefill_example_inputs,
-            context,
-            rbln_config.quantization,
-        )
+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_compile_configs[1]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            wrapped_model.phase = "image_prefill"
+            compiled_image_prefill = compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config.quantization,
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill
 
-        compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
         wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[2:]):
+        for batch_size, dec_compile_config in zip(
+            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+        ):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
                 wrapped_model,
@@ -731,35 +740,45 @@
     ) -> List[rebel.Runtime]:
         expected_model_names = [
             "prefill",
-            "image_prefill",
             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
         ]
+        if rbln_config.use_image_prefill:
+            expected_model_names.insert(1, "image_prefill")
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)
 
-        return [
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
-            ),
-            rebel.Runtime(
-                compiled_models[1],
-                tensor_type="pt",
-                device=rbln_config.device_map["image_prefill"],
-                activate_profiler=rbln_config.activate_profiler,
-                timeout=rbln_config.timeout,
-            ),
-            *[
+            )
+        ]
+        if rbln_config.use_image_prefill:
+            ret_val.append(
+                rebel.Runtime(
+                    compiled_models[1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map["image_prefill"],
+                    activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
+                ),
+            )
+
+        ret_val.extend(
+            [
                 rebel.Runtime(
-                    compiled_models[i + 2],
+                    compiled_models[i + rbln_config.decoder_runtime_idx],
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
                     timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ],
-        ]
+            ]
+        )
+
+        return ret_val

optimum/rbln/transformers/models/llava/__init__.py ADDED

@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_llava import RBLNLlavaForConditionalGenerationConfig
+from .modeling_llava import RBLNLlavaForConditionalGeneration

optimum/rbln/transformers/models/llava/configuration_llava.py ADDED

@@ -0,0 +1,54 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNLlavaForConditionalGenerationConfig.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized LLaVA models for multimodal conditional generation tasks
+    that combine vision and language processing capabilities.
+    """
+
+    submodules = ["vision_tower", "language_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_tower: Optional[RBLNModelConfig] = None,
+        language_model: Optional[RBLNModelConfig] = None,
+        **kwargs: Dict[str, Any],
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vision_tower = vision_tower
+        self.language_model = language_model
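
For reference, a minimal construction sketch based on the signature above; leaving the submodule configs as `None` is assumed to defer them to the package defaults:

```python
# Minimal sketch; submodule configs left as None on purpose.
config = RBLNLlavaForConditionalGenerationConfig(
    batch_size=1,
    vision_tower=None,
    language_model=None,
)
assert config.batch_size == 1
assert config.submodules == ["vision_tower", "language_model"]
```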