optimum-rbln 0.8.4a8__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln (0.9.2) has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (64):
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +7 -6
  61. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0
@@ -33,7 +33,6 @@ logger = get_logger(__name__)
33
33
 
34
34
 
35
35
  DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
36
- DEFAULT_MOD_NAME = "default"
37
36
  TypeInputInfo = List[Tuple[str, Tuple[int], str]]
38
37
 
39
38
 
@@ -41,6 +40,9 @@ TypeInputInfo = List[Tuple[str, Tuple[int], str]]
41
40
  class RBLNSerializableConfigProtocol(Protocol):
42
41
  def _prepare_for_serialization(self) -> Dict[str, Any]: ...
43
42
 
43
+ def __repr__(self) -> str:
44
+ return f"{self.__class__.__name__}({self._prepare_for_serialization()})"
45
+
44
46
 
45
47
  @dataclass
46
48
  class RBLNCompileConfig:
@@ -49,17 +51,13 @@ class RBLNCompileConfig:
49
51
 
50
52
  Attributes:
51
53
  compiled_model_name (str): Name of the compiled model.
52
- mod_name (str): Name of the RBLN module.
53
54
  input_info (Union[List[TypeInputInfo], TypeInputInfo]): Information about input tensors.
54
- fusion (Optional[bool]): Whether to use fusion optimization.
55
55
  npu (Optional[str]): NPU configuration.
56
56
  tensor_parallel_size (Optional[int]): Size for tensor parallelism.
57
57
  """
58
58
 
59
59
  compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
60
- mod_name: str = DEFAULT_MOD_NAME
61
60
  input_info: Union[List[TypeInputInfo], TypeInputInfo] = None
62
- fusion: Optional[bool] = None
63
61
  npu: Optional[str] = None
64
62
  tensor_parallel_size: Optional[int] = None
65
63
 
@@ -113,9 +111,7 @@ class RBLNCompileConfig:
113
111
 
114
112
  def update(self, kwargs: Dict[str, Any]):
115
113
  self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
116
- self.mod_name = kwargs.get("mod_name", self.mod_name)
117
114
  self.input_info = kwargs.get("input_info", self.input_info)
118
- self.fusion = kwargs.get("fusion", self.fusion)
119
115
  self.npu = kwargs.get("npu", self.npu)
120
116
  self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
121
117
  return self
@@ -149,7 +145,7 @@ class RBLNCompileConfig:
149
145
  return asdict(self)
150
146
 
151
147
 
152
- RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler", "timeout"]
148
+ RUNTIME_KEYWORDS = ["create_runtimes", "device", "device_map", "activate_profiler", "timeout"]
153
149
  CONFIG_MAPPING: Dict[str, Type["RBLNModelConfig"]] = {}
154
150
 
155
151
 
@@ -525,7 +521,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
525
521
  "npu",
526
522
  "tensor_parallel_size",
527
523
  "create_runtimes",
528
- "optimize_host_memory",
529
524
  "device",
530
525
  "device_map",
531
526
  "activate_profiler",
@@ -534,18 +529,18 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
534
529
  submodules: List[str] = []
535
530
  subclass_non_save_attributes = []
536
531
 
537
- def init_submodule_config(
532
+ def initialize_submodule_config(
538
533
  self,
539
- submodule_config_cls: Type["RBLNModelConfig"],
540
534
  submodule_config: Optional[Union[Dict[str, Any], "RBLNModelConfig"]] = None,
535
+ force_kwargs: bool = False,
541
536
  **kwargs: Any,
542
537
  ) -> "RBLNModelConfig":
543
- # Initialize a submodule config from a dict or a RBLNModelConfig.
544
- # kwargs is specified from the predecessor config.
545
-
546
538
  if submodule_config is None:
547
539
  submodule_config = {}
548
540
 
541
+ if isinstance(submodule_config, RBLNModelConfig):
542
+ return submodule_config
543
+
549
544
  if isinstance(submodule_config, dict):
550
545
  from_predecessor = self._runtime_options.copy()
551
546
  from_predecessor.update(
@@ -559,13 +554,60 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
559
554
 
560
555
  init_kwargs = from_predecessor
561
556
  init_kwargs.update(submodule_config)
562
- submodule_config = submodule_config_cls(**init_kwargs)
563
557
 
564
- if not isinstance(submodule_config, submodule_config_cls):
558
+ if force_kwargs:
559
+ for key, value in kwargs.items():
560
+ if key in init_kwargs:
561
+ if init_kwargs[key] != value:
562
+ raise ValueError(
563
+ f"Parameter conflict for '{key}': submodule_config has {init_kwargs[key]}, "
564
+ f"but kwargs has {value}. Using kwargs value: {value}"
565
+ )
566
+ init_kwargs[key] = value
567
+
568
+ if "cls_name" in init_kwargs:
569
+ config_cls = get_rbln_config_class(init_kwargs["cls_name"])
570
+ else:
571
+ return init_kwargs
572
+
573
+ submodule_config = config_cls(**init_kwargs)
574
+
575
+ if not isinstance(submodule_config, RBLNModelConfig):
565
576
  raise TypeError(f"Invalid submodule config type: {type(submodule_config)}")
566
577
 
567
578
  return submodule_config
568
579
 
580
+ def filter_parameters(self, config_cls: Type["RBLNModelConfig"], parameters: Dict[str, Any]) -> Dict[str, Any]:
581
+ import importlib
582
+
583
+ model_cls_name = config_cls.__name__.replace("Config", "")
584
+ modeling_module_name = config_cls.__module__.replace("configuration_", "modeling_")
585
+
586
+ model_cls = None
587
+ try:
588
+ modeling_module = importlib.import_module(modeling_module_name)
589
+ if hasattr(modeling_module, model_cls_name):
590
+ model_cls = getattr(modeling_module, model_cls_name)
591
+ except ImportError:
592
+ logger.debug(f"Could not import modeling module: {modeling_module_name}")
593
+
594
+ filtered_out_params = set()
595
+
596
+ if model_cls is not None:
597
+ if not getattr(model_cls, "_tp_support", False):
598
+ filtered_out_params.add("tensor_parallel_size")
599
+
600
+ filtered_params = {}
601
+ for key, value in parameters.items():
602
+ if key in filtered_out_params:
603
+ logger.debug(
604
+ f"Parameter '{key}' filtered out for {config_cls.__name__} (not supported by model flags)."
605
+ )
606
+ else:
607
+ filtered_params[key] = value
608
+
609
+ return filtered_params
610
+
569
611
  def __setattr__(self, key, value):
570
612
  if (
571
613
  key != "_attributes_map"
@@ -604,7 +646,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
604
646
  self,
605
647
  cls_name: Optional[str] = None,
606
648
  create_runtimes: Optional[bool] = None,
607
- optimize_host_memory: Optional[bool] = None,
608
649
  device: Optional[Union[int, List[int]]] = None,
609
650
  device_map: Optional[Dict[str, Union[int, List[int]]]] = None,
610
651
  activate_profiler: Optional[bool] = None,
@@ -614,6 +655,8 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
614
655
  optimum_rbln_version: Optional[str] = None,
615
656
  _torch_dtype: Optional[str] = None,
616
657
  _compile_cfgs: List[RBLNCompileConfig] = [],
658
+ *,
659
+ optimize_host_memory: Optional[bool] = None,
617
660
  **kwargs: Any,
618
661
  ):
619
662
  """
@@ -622,7 +665,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
622
665
  Args:
623
666
  cls_name (Optional[str]): The class name of the configuration. Defaults to the current class name.
624
667
  create_runtimes (Optional[bool]): Whether to create RBLN runtimes. Defaults to True.
625
- optimize_host_memory (Optional[bool]): Whether to optimize host memory usage. Defaults to True.
626
668
  device (Optional[Union[int, List[int]]]): The device(s) to load the model onto. Can be a single device ID or a list.
627
669
  device_map (Optional[Dict[str, Union[int, List[int]]]]): Mapping from compiled model names to device IDs.
628
670
  activate_profiler (Optional[bool]): Whether to activate the profiler for performance analysis.
@@ -648,12 +690,14 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
648
690
 
649
691
  self._runtime_options = {}
650
692
  self._runtime_options["create_runtimes"] = create_runtimes
651
- self._runtime_options["optimize_host_memory"] = optimize_host_memory
652
693
  self._runtime_options["device"] = device
653
694
  self._runtime_options["device_map"] = device_map
654
695
  self._runtime_options["activate_profiler"] = activate_profiler
655
696
  self._runtime_options["timeout"] = timeout
656
697
 
698
+ if optimize_host_memory is not None:
699
+ logger.warning("`optimize_host_memory` is deprecated and will be removed in future versions.")
700
+
657
701
  # Automatically pass npu, tensor_parallel_size to compile_cfgs
658
702
  self.npu = npu
659
703
  self.tensor_parallel_size = tensor_parallel_size
@@ -871,19 +915,6 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
871
915
  def create_runtimes(self, create_runtimes: bool):
872
916
  self._runtime_options["create_runtimes"] = create_runtimes
873
917
 
874
- @property
875
- def optimize_host_memory(self):
876
- context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
877
- if context is not None:
878
- return context
879
- elif self._runtime_options["optimize_host_memory"] is None:
880
- return True
881
- return self._runtime_options["optimize_host_memory"]
882
-
883
- @optimize_host_memory.setter
884
- def optimize_host_memory(self, optimize_host_memory: bool):
885
- self._runtime_options["optimize_host_memory"] = optimize_host_memory
886
-
887
918
  @property
888
919
  def device(self):
889
920
  context = ContextRblnConfig.get_current_context()["device"]
@@ -93,20 +93,27 @@ class RBLNStableDiffusionControlNetPipelineBaseConfig(RBLNModelConfig):
93
93
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
94
94
  raise ValueError("Both img_height and img_width must be provided together if used")
95
95
 
96
- self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
97
- self.unet = self.init_submodule_config(
98
- RBLNUNet2DConditionModelConfig,
96
+ self.text_encoder = self.initialize_submodule_config(
97
+ text_encoder,
98
+ cls_name="RBLNCLIPTextModelConfig",
99
+ batch_size=batch_size,
100
+ )
101
+ self.unet = self.initialize_submodule_config(
99
102
  unet,
103
+ cls_name="RBLNUNet2DConditionModelConfig",
100
104
  sample_size=sample_size,
101
105
  )
102
- self.vae = self.init_submodule_config(
103
- RBLNAutoencoderKLConfig,
106
+ self.vae = self.initialize_submodule_config(
104
107
  vae,
108
+ cls_name="RBLNAutoencoderKLConfig",
105
109
  batch_size=batch_size,
106
110
  uses_encoder=self.__class__._vae_uses_encoder,
107
111
  sample_size=image_size, # image size is equal to sample size in vae
108
112
  )
109
- self.controlnet = self.init_submodule_config(RBLNControlNetModelConfig, controlnet)
113
+ self.controlnet = self.initialize_submodule_config(
114
+ controlnet,
115
+ cls_name="RBLNControlNetModelConfig",
116
+ )
110
117
 
111
118
  # Get default guidance scale from original class to set UNet and ControlNet batch size
112
119
  if guidance_scale is None:
@@ -235,23 +242,32 @@ class RBLNStableDiffusionXLControlNetPipelineBaseConfig(RBLNModelConfig):
235
242
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
236
243
  raise ValueError("Both img_height and img_width must be provided together if used")
237
244
 
238
- self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
239
- self.text_encoder_2 = self.init_submodule_config(
240
- RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
245
+ self.text_encoder = self.initialize_submodule_config(
246
+ text_encoder,
247
+ cls_name="RBLNCLIPTextModelConfig",
248
+ batch_size=batch_size,
241
249
  )
242
- self.unet = self.init_submodule_config(
243
- RBLNUNet2DConditionModelConfig,
250
+ self.text_encoder_2 = self.initialize_submodule_config(
251
+ text_encoder_2,
252
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
253
+ batch_size=batch_size,
254
+ )
255
+ self.unet = self.initialize_submodule_config(
244
256
  unet,
257
+ cls_name="RBLNUNet2DConditionModelConfig",
245
258
  sample_size=sample_size,
246
259
  )
247
- self.vae = self.init_submodule_config(
248
- RBLNAutoencoderKLConfig,
260
+ self.vae = self.initialize_submodule_config(
249
261
  vae,
262
+ cls_name="RBLNAutoencoderKLConfig",
250
263
  batch_size=batch_size,
251
264
  uses_encoder=self.__class__._vae_uses_encoder,
252
265
  sample_size=image_size, # image size is equal to sample size in vae
253
266
  )
254
- self.controlnet = self.init_submodule_config(RBLNControlNetModelConfig, controlnet)
267
+ self.controlnet = self.initialize_submodule_config(
268
+ controlnet,
269
+ cls_name="RBLNControlNetModelConfig",
270
+ )
255
271
 
256
272
  # Get default guidance scale from original class to set UNet and ControlNet batch size
257
273
  guidance_scale = (
@@ -63,12 +63,15 @@ class RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
63
63
  """
64
64
  super().__init__(**kwargs)
65
65
 
66
- self.text_encoder = self.init_submodule_config(
67
- RBLNT5EncoderModelConfig, text_encoder, batch_size=batch_size, max_seq_len=max_seq_len
66
+ self.text_encoder = self.initialize_submodule_config(
67
+ text_encoder,
68
+ cls_name="RBLNT5EncoderModelConfig",
69
+ batch_size=batch_size,
70
+ max_seq_len=max_seq_len,
68
71
  )
69
- self.transformer = self.init_submodule_config(
70
- RBLNCosmosTransformer3DModelConfig,
72
+ self.transformer = self.initialize_submodule_config(
71
73
  transformer,
74
+ cls_name="RBLNCosmosTransformer3DModelConfig",
72
75
  batch_size=batch_size,
73
76
  max_seq_len=max_seq_len,
74
77
  height=height,
@@ -76,18 +79,18 @@ class RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
76
79
  num_frames=num_frames,
77
80
  fps=fps,
78
81
  )
79
- self.vae = self.init_submodule_config(
80
- RBLNAutoencoderKLCosmosConfig,
82
+ self.vae = self.initialize_submodule_config(
81
83
  vae,
84
+ cls_name="RBLNAutoencoderKLCosmosConfig",
82
85
  batch_size=batch_size,
83
86
  uses_encoder=self.__class__._vae_uses_encoder,
84
87
  height=height,
85
88
  width=width,
86
89
  num_frames=num_frames,
87
90
  )
88
- self.safety_checker = self.init_submodule_config(
89
- RBLNCosmosSafetyCheckerConfig,
91
+ self.safety_checker = self.initialize_submodule_config(
90
92
  safety_checker,
93
+ cls_name="RBLNCosmosSafetyCheckerConfig",
91
94
  batch_size=batch_size,
92
95
  height=height,
93
96
  width=width,
@@ -88,10 +88,14 @@ class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
88
88
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
89
89
  raise ValueError("Both img_height and img_width must be provided together if used")
90
90
 
91
- self.unet = self.init_submodule_config(RBLNUNet2DConditionModelConfig, unet, sample_size=sample_size)
92
- self.movq = self.init_submodule_config(
93
- RBLNVQModelConfig,
91
+ self.unet = self.initialize_submodule_config(
92
+ unet,
93
+ cls_name="RBLNUNet2DConditionModelConfig",
94
+ sample_size=sample_size,
95
+ )
96
+ self.movq = self.initialize_submodule_config(
94
97
  movq,
98
+ cls_name="RBLNVQModelConfig",
95
99
  batch_size=batch_size,
96
100
  sample_size=image_size, # image size is equal to sample size in vae
97
101
  uses_encoder=self._movq_uses_encoder,
@@ -173,14 +177,20 @@ class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
173
177
  accommodate classifier-free guidance.
174
178
  """
175
179
  super().__init__(**kwargs)
176
- self.text_encoder = self.init_submodule_config(
177
- RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
180
+ self.text_encoder = self.initialize_submodule_config(
181
+ text_encoder,
182
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
183
+ batch_size=batch_size,
178
184
  )
179
- self.image_encoder = self.init_submodule_config(
180
- RBLNCLIPVisionModelWithProjectionConfig, image_encoder, batch_size=batch_size
185
+ self.image_encoder = self.initialize_submodule_config(
186
+ image_encoder,
187
+ cls_name="RBLNCLIPVisionModelWithProjectionConfig",
188
+ batch_size=batch_size,
189
+ )
190
+ self.prior = self.initialize_submodule_config(
191
+ prior,
192
+ cls_name="RBLNPriorTransformerConfig",
181
193
  )
182
-
183
- self.prior = self.init_submodule_config(RBLNPriorTransformerConfig, prior)
184
194
 
185
195
  # Get default guidance scale from original class to set UNet batch size
186
196
  if guidance_scale is None:
@@ -286,18 +296,18 @@ class RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
286
296
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
287
297
  raise ValueError("Both img_height and img_width must be provided together if used")
288
298
 
289
- self.prior_pipe = self.init_submodule_config(
290
- RBLNKandinskyV22PriorPipelineConfig,
299
+ self.prior_pipe = self.initialize_submodule_config(
291
300
  prior_pipe,
301
+ cls_name="RBLNKandinskyV22PriorPipelineConfig",
292
302
  prior=prior_prior,
293
303
  image_encoder=prior_image_encoder,
294
304
  text_encoder=prior_text_encoder,
295
305
  batch_size=batch_size,
296
306
  guidance_scale=guidance_scale,
297
307
  )
298
- self.decoder_pipe = self.init_submodule_config(
299
- self._decoder_pipe_cls,
308
+ self.decoder_pipe = self.initialize_submodule_config(
300
309
  decoder_pipe,
310
+ cls_name=self._decoder_pipe_cls.__name__,
301
311
  unet=unet,
302
312
  movq=movq,
303
313
  batch_size=batch_size,
@@ -90,18 +90,22 @@ class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
90
90
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
91
91
  raise ValueError("Both img_height and img_width must be provided together if used")
92
92
 
93
- self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
94
- self.unet = self.init_submodule_config(
95
- RBLNUNet2DConditionModelConfig,
93
+ self.text_encoder = self.initialize_submodule_config(
94
+ text_encoder,
95
+ cls_name="RBLNCLIPTextModelConfig",
96
+ batch_size=batch_size,
97
+ )
98
+ self.unet = self.initialize_submodule_config(
96
99
  unet,
100
+ cls_name="RBLNUNet2DConditionModelConfig",
97
101
  sample_size=sample_size,
98
102
  )
99
- self.vae = self.init_submodule_config(
100
- RBLNAutoencoderKLConfig,
103
+ self.vae = self.initialize_submodule_config(
101
104
  vae,
105
+ cls_name="RBLNAutoencoderKLConfig",
102
106
  batch_size=batch_size,
103
107
  uses_encoder=self.__class__._vae_uses_encoder,
104
- sample_size=image_size, # image size is equal to sample size in vae
108
+ sample_size=image_size,
105
109
  )
106
110
 
107
111
  # Get default guidance scale from original class to set UNet batch size
@@ -100,27 +100,31 @@ class RBLNStableDiffusion3PipelineBaseConfig(RBLNModelConfig):
100
100
 
101
101
  max_seq_len = max_seq_len or 256
102
102
 
103
- self.text_encoder = self.init_submodule_config(
104
- RBLNCLIPTextModelWithProjectionConfig, text_encoder, batch_size=batch_size
103
+ self.text_encoder = self.initialize_submodule_config(
104
+ text_encoder,
105
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
106
+ batch_size=batch_size,
105
107
  )
106
- self.text_encoder_2 = self.init_submodule_config(
107
- RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
108
+ self.text_encoder_2 = self.initialize_submodule_config(
109
+ text_encoder_2,
110
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
111
+ batch_size=batch_size,
108
112
  )
109
- self.text_encoder_3 = self.init_submodule_config(
110
- RBLNT5EncoderModelConfig,
113
+ self.text_encoder_3 = self.initialize_submodule_config(
111
114
  text_encoder_3,
115
+ cls_name="RBLNT5EncoderModelConfig",
112
116
  batch_size=batch_size,
113
117
  max_seq_len=max_seq_len,
114
118
  model_input_names=["input_ids"],
115
119
  )
116
- self.transformer = self.init_submodule_config(
117
- RBLNSD3Transformer2DModelConfig,
120
+ self.transformer = self.initialize_submodule_config(
118
121
  transformer,
122
+ cls_name="RBLNSD3Transformer2DModelConfig",
119
123
  sample_size=sample_size,
120
124
  )
121
- self.vae = self.init_submodule_config(
122
- RBLNAutoencoderKLConfig,
125
+ self.vae = self.initialize_submodule_config(
123
126
  vae,
127
+ cls_name="RBLNAutoencoderKLConfig",
124
128
  batch_size=batch_size,
125
129
  uses_encoder=self.__class__._vae_uses_encoder,
126
130
  sample_size=image_size,
@@ -93,18 +93,25 @@ class RBLNStableDiffusionXLPipelineBaseConfig(RBLNModelConfig):
93
93
  elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
94
94
  raise ValueError("Both img_height and img_width must be provided together if used")
95
95
 
96
- self.text_encoder = self.init_submodule_config(RBLNCLIPTextModelConfig, text_encoder, batch_size=batch_size)
97
- self.text_encoder_2 = self.init_submodule_config(
98
- RBLNCLIPTextModelWithProjectionConfig, text_encoder_2, batch_size=batch_size
96
+ self.text_encoder = self.initialize_submodule_config(
97
+ text_encoder,
98
+ cls_name="RBLNCLIPTextModelConfig",
99
+ batch_size=batch_size,
100
+ )
101
+ self.text_encoder_2 = self.initialize_submodule_config(
102
+ text_encoder_2,
103
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
104
+ batch_size=batch_size,
99
105
  )
100
- self.unet = self.init_submodule_config(
101
- RBLNUNet2DConditionModelConfig,
106
+
107
+ self.unet = self.initialize_submodule_config(
102
108
  unet,
109
+ cls_name="RBLNUNet2DConditionModelConfig",
103
110
  sample_size=sample_size,
104
111
  )
105
- self.vae = self.init_submodule_config(
106
- RBLNAutoencoderKLConfig,
112
+ self.vae = self.initialize_submodule_config(
107
113
  vae,
114
+ cls_name="RBLNAutoencoderKLConfig",
108
115
  batch_size=batch_size,
109
116
  uses_encoder=self.__class__._vae_uses_encoder,
110
117
  sample_size=image_size, # image size is equal to sample size in vae
@@ -244,7 +244,6 @@ class RBLNDiffusionMixin:
244
244
  device=rbln_config.device,
245
245
  device_map=rbln_config.device_map,
246
246
  create_runtimes=rbln_config.create_runtimes,
247
- optimize_host_mem=rbln_config.optimize_host_memory,
248
247
  activate_profiler=rbln_config.activate_profiler,
249
248
  timeout=rbln_config.timeout,
250
249
  ):
@@ -412,12 +411,11 @@ class RBLNDiffusionMixin:
412
411
  # overwrite to replace incorrect config
413
412
  model.save_config(model_save_dir)
414
413
 
415
- if rbln_config.optimize_host_memory is False:
416
- # Keep compiled_model objs to further analysis. -> TODO: remove soon...
417
- model.compiled_models = []
418
- for name in cls._submodules:
419
- submodule = getattr(model, name)
420
- model.compiled_models.extend(submodule.compiled_models)
414
+ # Keep compiled_model objs to further analysis. -> TODO: remove soon...
415
+ model.compiled_models = []
416
+ for name in cls._submodules:
417
+ submodule = getattr(model, name)
418
+ model.compiled_models.extend(submodule.compiled_models)
421
419
 
422
420
  return model
423
421
 
@@ -15,7 +15,7 @@
15
15
  from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
- from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig
18
+ from ....transformers import RBLNSiglipVisionModelConfig
19
19
 
20
20
 
21
21
  class RBLNVideoSafetyModelConfig(RBLNModelConfig):
@@ -81,30 +81,28 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
81
81
 
82
82
  tensor_parallel_size = kwargs.get("tensor_parallel_size")
83
83
 
84
- self.llamaguard3 = self.init_submodule_config(
85
- RBLNLlamaForCausalLMConfig,
84
+ self.llamaguard3 = self.initialize_submodule_config(
86
85
  llamaguard3,
86
+ cls_name="RBLNLlamaForCausalLMConfig",
87
87
  batch_size=batch_size,
88
88
  tensor_parallel_size=tensor_parallel_size,
89
89
  max_seq_len=max_seq_len,
90
90
  )
91
-
92
- self.siglip_encoder = self.init_submodule_config(
93
- RBLNSiglipVisionModelConfig,
91
+ self.siglip_encoder = self.initialize_submodule_config(
94
92
  siglip_encoder,
93
+ cls_name="RBLNSiglipVisionModelConfig",
95
94
  batch_size=batch_size,
96
95
  image_size=(384, 384),
97
96
  )
98
-
99
- self.video_safety_model = self.init_submodule_config(
100
- RBLNVideoSafetyModelConfig,
97
+ self.video_safety_model = self.initialize_submodule_config(
101
98
  video_safety_model,
99
+ cls_name="RBLNVideoSafetyModelConfig",
102
100
  batch_size=batch_size,
103
101
  input_size=1152,
104
102
  )
105
- self.face_blur_filter = self.init_submodule_config(
106
- RBLNRetinaFaceFilterConfig,
103
+ self.face_blur_filter = self.initialize_submodule_config(
107
104
  face_blur_filter,
105
+ cls_name="RBLNRetinaFaceFilterConfig",
108
106
  batch_size=batch_size,
109
107
  image_size=image_size,
110
108
  )
optimum/rbln/modeling.py CHANGED
@@ -34,6 +34,49 @@ if TYPE_CHECKING:
34
34
  logger = get_logger(__name__)
35
35
 
36
36
 
37
+ def _get_dtype(
38
+ cls,
39
+ dtype: Optional[Union[str, torch.dtype, dict]],
40
+ config: PretrainedConfig,
41
+ ) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
42
+ dtype_orig = None
43
+
44
+ if dtype is not None:
45
+ if isinstance(dtype, str):
46
+ if dtype == "auto":
47
+ if hasattr(config, "dtype") and config.dtype is not None:
48
+ dtype = config.dtype
49
+ else:
50
+ dtype = torch.get_default_dtype()
51
+ elif hasattr(torch, dtype):
52
+ dtype = getattr(torch, dtype)
53
+ config.dtype = dtype
54
+ elif isinstance(dtype, torch.dtype):
55
+ config.dtype = dtype
56
+ elif isinstance(dtype, dict):
57
+ for key, curr_dtype in dtype.items():
58
+ if hasattr(config, key):
59
+ value = getattr(config, key)
60
+ curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
61
+ value.dtype = curr_dtype
62
+ # main torch dtype for modules that aren't part of any sub-config
63
+ dtype = dtype.get("")
64
+ dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
65
+ config.dtype = dtype
66
+ if dtype is None:
67
+ dtype = torch.float32
68
+ else:
69
+ raise ValueError(f"Invalid dtype: {dtype}")
70
+
71
+ dtype_orig = cls._set_default_dtype(dtype)
72
+ else:
73
+ # Use default dtype
74
+ default_dtype = torch.get_default_dtype()
75
+ config.dtype = default_dtype
76
+
77
+ return config, dtype, dtype_orig
78
+
79
+
37
80
  class RBLNModel(RBLNBaseModel):
38
81
  @classmethod
39
82
  def update_kwargs(cls, kwargs):
@@ -70,6 +113,10 @@ class RBLNModel(RBLNBaseModel):
70
113
  )
71
114
  return compiled_model
72
115
 
116
+ @classmethod
117
+ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
118
+ return model
119
+
73
120
  @classmethod
74
121
  def from_model(
75
122
  cls,
@@ -103,6 +150,8 @@ class RBLNModel(RBLNBaseModel):
103
150
  Returns:
104
151
  (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
105
152
  """
153
+
154
+ model = cls._reconstruct_model_if_needed(model)
106
155
  preprocessors = kwargs.pop("preprocessors", [])
107
156
  rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
108
157
 
@@ -209,6 +258,7 @@ class RBLNModel(RBLNBaseModel):
209
258
  **kwargs,
210
259
  ) -> "PreTrainedModel":
211
260
  kwargs = cls.update_kwargs(kwargs)
261
+
212
262
  return cls.get_hf_class().from_pretrained(
213
263
  model_id,
214
264
  subfolder=subfolder,
@@ -315,7 +315,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
315
315
  rbln_config,
316
316
  model_save_dir=model_save_dir,
317
317
  subfolder=subfolder,
318
- rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
318
+ rbln_compiled_models=rbln_compiled_models,
319
319
  rbln_submodules=rbln_submodules,
320
320
  **kwargs,
321
321
  )
@@ -433,7 +433,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
433
433
  compiled_model = rebel.compile_from_torch(
434
434
  model,
435
435
  input_info=rbln_compile_config.input_info,
436
- fusion=rbln_compile_config.fusion,
437
436
  npu=rbln_compile_config.npu,
438
437
  tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
439
438
  **kwargs,