optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling.py CHANGED
@@ -14,15 +14,16 @@
 
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import rebel
 import torch
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from transformers import AutoConfig, PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
 
+from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
 from .modeling_base import RBLNBaseModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, use_rbln_config
 from .utils.logging import get_logger
 
 
@@ -48,6 +49,9 @@ class RBLNModel(RBLNBaseModel):
     ```
     """
 
+    output_class = None
+    output_key = "last_hidden_state"
+
     @classmethod
     def update_kwargs(cls, kwargs):
         """
@@ -56,12 +60,7 @@ class RBLNModel(RBLNBaseModel):
         For example, `torchscript`=True should be set because torch.jit
         does not support `transformers` output instances as module output;
         """
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-            }
-        )
+        kwargs.update({"torchscript": True})
         return kwargs
 
     @classmethod
@@ -70,7 +69,7 @@ class RBLNModel(RBLNBaseModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNModelConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
@@ -78,30 +77,29 @@
         """
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        # Wrap the model if needed.
         return model
 
     @classmethod
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
         return compiled_model
 
     @classmethod
-    @use_rbln_config
     def from_model(
         cls,
         model: "PreTrainedModel",
         config: Optional[PretrainedConfig] = None,
-        rbln_config: Dict[str, Any] = {},
+        rbln_config: Optional[RBLNModelConfig] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         **kwargs,
     ):
         preprocessors = kwargs.pop("preprocessors", [])
-        rbln_kwargs = rbln_config
+        rbln_config, kwargs = cls.prepare_rbln_config(rbln_config=rbln_config, **kwargs)
 
         # Directory to save compile artifacts(.rbln) and original configs
         if model_save_dir is None:
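`from_model` now accepts a typed `RBLNModelConfig` (or `rbln_*` kwargs routed through `prepare_rbln_config`) instead of a loose dict. A minimal sketch of the new call-site; the concrete class names and the `batch_size` argument are illustrative assumptions, not taken verbatim from this diff:

```python
# Sketch only. Assumes RBLNLlamaForCausalLM / RBLNLlamaForCausalLMConfig are
# exported by optimum.rbln; the Model -> ModelConfig name pairing follows the
# convention enforced by get_rbln_config_class() later in this diff.
from transformers import AutoModelForCausalLM
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

torch_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

rbln_config = RBLNLlamaForCausalLMConfig(batch_size=1)  # argument name assumed
compiled = RBLNLlamaForCausalLM.from_model(torch_model, rbln_config=rbln_config)
```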
@@ -123,8 +121,15 @@
             config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
 
         if hasattr(model, "can_generate") and model.can_generate():
+            import json
+
             generation_config = model.generation_config
-            generation_config.save_pretrained(save_dir_path / subfolder)
+            generation_config_path = save_dir_path / subfolder / "generation_config.json"
+
+            generation_config.save_pretrained(generation_config_path.parent)
+            local_config = json.loads(generation_config_path.read_text(encoding="utf-8"))
+            local_config["transformers_version"] = generation_config.transformers_version
+            generation_config_path.write_text(json.dumps(local_config, indent=2) + "\n", encoding="utf-8")
 
         if not isinstance(config, PretrainedConfig): # diffusers config
             config = PretrainedConfig(**config)
@@ -134,14 +139,21 @@
         for preprocessor in preprocessors:
             preprocessor.save_pretrained(save_dir_path / subfolder)
 
-        # ad-hoc
-        rbln_kwargs["n_model_params"] = sum(p.numel() for p in model.parameters())
+        # Load submodules
+        if len(cls._rbln_submodules) > 0:
+            rbln_submodules = cls._load_submodules(
+                model=model,
+                model_save_dir=save_dir,
+                rbln_config=rbln_config,
+                **kwargs,
+            )
+        else:
+            rbln_submodules = []
 
         # Get compilation arguments (e.g. input_info)
-        rbln_config: RBLNConfig = cls.get_rbln_config(
-            preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
+        rbln_config: RBLNModelConfig = cls.update_rbln_config(
+            preprocessors=preprocessors, model=model, model_config=config, rbln_config=rbln_config
         )
-        # rbln_config.update_runtime_cfg(rbln_kwargs) # This is done in get_rbln_config
 
         compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
             model, rbln_config=rbln_config
@@ -160,17 +172,6 @@
         # Save torch artifacts (e.g. embedding matrix if needed.)
         cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
 
-        # Load submodules
-        if len(cls._rbln_submodules) > 0:
-            rbln_submodules = cls._load_submodules(
-                model=model,
-                model_save_dir=save_dir,
-                rbln_kwargs=rbln_kwargs,
-                **kwargs,
-            )
-        else:
-            rbln_submodules = []
-
         # Instantiate
         return cls._from_pretrained(
             model_id=save_dir_path,
@@ -194,8 +195,8 @@
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        # Some rbln-kwargs should be applied before loading torch module (i.e. quantized llm)
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
+        # Some rbln-config should be applied before loading torch module (i.e. quantized llm)
+        rbln_config: Optional[RBLNModelConfig] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         kwargs = cls.update_kwargs(kwargs)
@@ -215,18 +216,43 @@
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
-        if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
+        if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
 
-        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            compiled_model.create_runtime(tensor_type="pt", device=device, activate_profiler=activate_profiler)
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
+                activate_profiler=rbln_config.activate_profiler,
+            )
             for compiled_model in compiled_models
         ]
 
-    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+    def forward(self, *args, return_dict: Optional[bool] = None, **kwargs):
+        if self.hf_library_name == "transformers":
+            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        else:
+            return_dict = True if return_dict is None else return_dict
+
+        # Get output from the model
         output = self.model[0](*args, **kwargs)
-        return output
+
+        # Format output according to task requirements
+        return self._prepare_output(output, return_dict)
+
+    def _prepare_output(self, output, return_dict):
+        """
+        Prepare model output based on return_dict flag.
+        This method can be overridden by subclasses to provide task-specific output handling.
+        """
+        if not return_dict:
+            return (output,) if not isinstance(output, (tuple, list)) else output
+        else:
+            if self.output_class is None:
+                return BaseModelOutput(last_hidden_state=output)
+
+            # Create output with the appropriate class and key
+            return self.output_class(**{self.output_key: output})
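With `return_dict=False` no longer forced at trace time (see `update_kwargs` above), output formatting becomes a per-call decision in `_prepare_output`. A behavior sketch, where `rbln_model` and `inputs` are placeholders for any loaded `RBLNModel` and its input tensors:

```python
# Default: a transformers-style ModelOutput (BaseModelOutput when
# output_class is None, keyed by output_key = "last_hidden_state").
out = rbln_model(inputs)
hidden = out.last_hidden_state

# Legacy tuple output, matching the pre-0.7.4 behavior.
legacy = rbln_model(inputs, return_dict=False)
hidden = legacy[0]
```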
optimum/rbln/modeling_base.py CHANGED
@@ -18,18 +18,13 @@ import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
 import rebel
 import torch
-from transformers import (
-    AutoConfig,
-    AutoModel,
-    GenerationConfig,
-    PretrainedConfig,
-)
-
-from .modeling_config import RBLNCompileConfig, RBLNConfig, use_rbln_config
+from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
+
+from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig
 from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
 from .utils.runtime_utils import UnavailableRuntime
@@ -47,6 +42,10 @@ class PreTrainedModel(ABC): # noqa: F811
     pass
 
 
+class RBLNBaseModelConfig(RBLNModelConfig):
+    pass
+
+
 class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     """
     An abstract base class for compiling, loading, and saving neural network models from the huggingface
@@ -85,15 +84,17 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     model_type = "rbln_model"
     auto_model_class = AutoModel
     config_class = AutoConfig
+
     config_name = "config.json"
     hf_library_name = "transformers"
     _hf_class = None
+    _rbln_config_class = None
 
     def __init__(
         self,
         models: List[rebel.Runtime],
         config: "PretrainedConfig",
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNModelConfig,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
@@ -103,6 +104,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         self.model = models
         self.config = config
         self.rbln_config = rbln_config
+        if not rbln_config.is_frozen():
+            raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
+
         self.compiled_models = rbln_compiled_models
 
         # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
@@ -118,7 +122,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         else:
             self.generation_config = None
 
-        # self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         if self.generation_config is not None:
             self.generation_config.use_cache = True
 
@@ -181,11 +184,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return rbln_compiled_models
 
     @classmethod
-    @use_rbln_config
     def _from_pretrained(
         cls,
         model_id: Union[str, Path],
-        config: "PretrainedConfig" = None,
+        config: Optional["PretrainedConfig"] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -195,17 +197,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         trust_remote_code: bool = False,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         # passed from compile function
-        rbln_config: Optional[RBLNConfig] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
         rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
         rbln_submodules: List["RBLNBaseModel"] = [],
         **kwargs,
     ) -> "RBLNBaseModel":
-        from_export_method = isinstance(rbln_config, RBLNConfig) and rbln_compiled_models is not None
-
-        if not from_export_method:
-            # from compiled dir
-            rbln_kwargs = rbln_config or {}
-
+        if rbln_compiled_models is None:
             model_path_subfolder = cls._load_compiled_model_dir(
                 model_id=model_id,
                 use_auth_token=use_auth_token,
@@ -216,16 +213,34 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 local_files_only=local_files_only,
             )
 
-            rbln_config = RBLNConfig.load(model_path_subfolder)
-            rbln_config.update_runtime_cfg(rbln_kwargs)
+            if isinstance(rbln_config, dict):
+                rbln_config_as_kwargs = {f"rbln_{key}": value for key, value in rbln_config.items()}
+                kwargs.update(rbln_config_as_kwargs)
+                rbln_config = None
+            elif isinstance(rbln_config, RBLNModelConfig) and rbln_config.rbln_model_cls_name != cls.__name__:
+                raise ValueError(
+                    f"Cannot use the passed rbln_config. Its model class name ({rbln_config.rbln_model_cls_name}) "
+                    f"does not match the expected model class name ({cls.__name__})."
+                )
+
+            rbln_config, kwargs = RBLNAutoConfig.load(
+                model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
+            )
 
-            if rbln_config.meta["cls"] != cls.__name__:
+            if rbln_config.rbln_model_cls_name != cls.__name__:
                 raise NameError(
                     f"Cannot load the model. The model was originally compiled using "
-                    f"{rbln_config.meta['cls']}, but you are trying to load it with {cls.__name__}."
+                    f"{rbln_config.rbln_model_cls_name}, but you are trying to load it with {cls.__name__}."
                     "Please use the same model class that was used during compilation."
                 )
 
+            if len(cls._rbln_submodules) > 0:
+                rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
+            else:
+                rbln_submodules = []
+
+            rbln_config.freeze()
+
             if config is None:
                 if cls.hf_library_name == "transformers":
                     config = AutoConfig.from_pretrained(
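At load time, `rbln_config` may now be either a plain dict, which is expanded back into `rbln_*` kwargs, or a typed `RBLNModelConfig` whose `rbln_model_cls_name` must match the loading class. A minimal sketch; the model class name is an assumed example:

```python
# Sketch: both forms funnel into RBLNAutoConfig.load() above.
# Dict form -- keys are re-prefixed to rbln_* kwargs internally:
model = RBLNLlamaForCausalLM.from_pretrained(
    "path/to/compiled-model",
    rbln_config={"create_runtimes": False},
)
# A typed RBLNModelConfig must have been built for the same class, otherwise
# the ValueError above ("Cannot use the passed rbln_config ...") is raised.
```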
@@ -258,15 +273,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
             rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
 
-            if len(cls._rbln_submodules) > 0:
-                rbln_submodules = cls._load_submodules(
-                    model_save_dir=model_id,
-                    rbln_kwargs=rbln_kwargs,
-                    **kwargs,
-                )
-            else:
-                rbln_submodules = []
-
             if subfolder != "":
                 model_save_dir = Path(model_path_subfolder).absolute().parent
             else:
@@ -286,7 +292,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     def _from_compiled_models(
         cls,
         rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNModelConfig,
         config: "PretrainedConfig",
         model_save_dir: Union[Path, str],
         subfolder: Union[Path, str],
@@ -303,16 +309,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         # create runtimes only if `rbln_create_runtimes` is enabled
         try:
             models = (
-                cls._create_runtimes(rbln_compiled_models, rbln_config.device_map, rbln_config.activate_profiler)
+                cls._create_runtimes(rbln_compiled_models, rbln_config)
                 if rbln_config.create_runtimes
                 else UnavailableRuntime()
             )
 
         except rebel.core.exception.RBLNRuntimeError as e:
-            logger.warning(
-                f"Failed to create the runtime for the model due to a runtime error: {e.__class__.__name__} - {e}"
+            error_msg = (
+                f"\nFailed to create RBLN runtime: {str(e)}\n\n"
+                f"If you only need to compile the model without loading it to NPU, you can use:\n"
+                f"  from_pretrained(..., rbln_create_runtimes=False) or\n"
+                f"  from_pretrained(..., rbln_config={{..., 'create_runtimes': False}})\n\n"
+                f"To check your NPU status, run the 'rbln-stat' command in your terminal.\n"
+                f"Make sure your NPU is properly installed and operational."
             )
-            models = UnavailableRuntime()
+            raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
 
         return cls(
             models,
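Runtime-creation failures are now raised as hard errors instead of silently degrading to `UnavailableRuntime`, and the new message spells out the compile-only escape hatch. As a usage sketch of that workflow (concrete class assumed):

```python
# Compile on a host without a reachable NPU, skipping runtime creation,
# as the new error message suggests.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    export=True,
    rbln_create_runtimes=False,
)
model.save_pretrained("compiled-llama")  # runtimes are created at load time instead
```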
@@ -326,38 +337,31 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         )
 
     @classmethod
-    @use_rbln_config
-    def _export(
-        cls,
-        model_id: Union[str, Path],
-        rbln_config: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "RBLNBaseModel":
+    def _export(cls, model_id: Union[str, Path], **kwargs) -> "RBLNBaseModel":
         subfolder = kwargs.get("subfolder", "")
         model_save_dir = kwargs.pop("model_save_dir", None)
 
-        rbln_kwargs = rbln_config
-        model: "PreTrainedModel" = cls.get_pytorch_model(
-            model_id=model_id,
-            rbln_kwargs=rbln_kwargs,
-            **kwargs,
-        )
+        rbln_config, kwargs = cls.prepare_rbln_config(**kwargs)
+
+        model: "PreTrainedModel" = cls.get_pytorch_model(model_id=model_id, rbln_config=rbln_config, **kwargs)
         preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
         return cls.from_model(
-            model,
-            rbln_config=rbln_config,
-            preprocessors=preprocessors,
-            model_save_dir=model_save_dir,
-            **kwargs,
+            model, preprocessors=preprocessors, model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
         )
 
     @classmethod
-    def from_pretrained(
-        cls,
-        model_id: Union[str, Path],
-        export: bool = False,
-        **kwargs,
-    ) -> "RBLNBaseModel":
+    def prepare_rbln_config(
+        cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
+    ) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
+        """
+        Extract rbln-config from kwargs and convert it to RBLNModelConfig.
+        """
+        config_cls = cls.get_rbln_config_class()
+        rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
+        return rbln_config, kwargs
+
+    @classmethod
+    def from_pretrained(cls, model_id: Union[str, Path], export: bool = False, **kwargs) -> "RBLNBaseModel":
         if isinstance(model_id, Path):
             model_id = model_id.as_posix()
         from_pretrained_method = cls._export if export else cls._from_pretrained
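`prepare_rbln_config` is now the single funnel that turns `rbln_`-prefixed kwargs, a dict, or an existing `RBLNModelConfig` into one typed config object. A sketch of the intended equivalence (class and kwarg names illustrative; the dict form is assumed to be accepted by `initialize_from_kwargs` as it is on the load path):

```python
model_id = "meta-llama/Llama-2-7b-hf"

# All three spellings should converge on the same RBLNModelConfig.
m1 = RBLNLlamaForCausalLM.from_pretrained(model_id, export=True, rbln_batch_size=2)
m2 = RBLNLlamaForCausalLM.from_pretrained(model_id, export=True, rbln_config={"batch_size": 2})
m3 = RBLNLlamaForCausalLM.from_pretrained(
    model_id, export=True, rbln_config=RBLNLlamaForCausalLMConfig(batch_size=2)
)
```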
@@ -376,29 +380,26 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         return compiled_model
 
     @classmethod
-    def get_rbln_config(
-        cls,
-        rbln_kwargs: Dict[str, Any],
-        **others,
-    ) -> RBLNConfig:
-        """
-        Make default rbln-config for the model.
-        kwargs for overriding model's config can be accepted.
-        Note that batch_size should be specified with proper input_info.
-        """
-        rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
+    def update_rbln_config(cls, **others) -> RBLNModelConfig:
+        rbln_config = cls._update_rbln_config(**others)
+        rbln_config.freeze()
+        if rbln_config.rbln_model_cls_name != cls.__name__:
+            raise NameError(
+                f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
+                "This is an internal error. Please report it to the developers."
+            )
         return rbln_config
 
     @classmethod
     def get_hf_class(cls):
         """
-        Lazily loads and caches the corresponding Hugging Face model class.
+        Lazily loads and caches the corresponding HuggingFace model class.
         Removes 'RBLN' prefix from the class name to get the original class name
         (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
         the transformers/diffusers module.
 
         Returns:
-            type: The original Hugging Face model class
+            type: The original HuggingFace model class
         """
         if cls._hf_class is None:
             hf_cls_name = cls.__name__[4:]
@@ -406,12 +407,42 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             cls._hf_class = getattr(library, hf_cls_name, None)
         return cls._hf_class
 
+    @classmethod
+    def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
+        """
+        Lazily loads and caches the corresponding RBLN model config class.
+        """
+        if cls._rbln_config_class is None:
+            rbln_config_class_name = cls.__name__ + "Config"
+            library = importlib.import_module("optimum.rbln")
+            cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
+            if cls._rbln_config_class is None:
+                raise ValueError(
+                    f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
+                    "Please report it to the developers."
+                )
+        return cls._rbln_config_class
+
     def can_generate(self):
         return False
 
     def to(self, *args, **kwargs):
         return self
 
+    def parameters(self):
+        """
+        Provides a dummy parameter generator for compatibility.
+
+        This method mimics the interface of torch.nn.Module.parameters()
+        specifically for code that uses `next(model.parameters())` to infer
+        the device or dtype. It yields a single dummy tensor on CPU with float32 dtype.
+
+        Warning:
+            This does NOT yield the actual model parameters used by the RBLN runtime.
+            Code relying on iterating through all model parameters will not work as expected.
+        """
+        yield torch.tensor([1.0], dtype=torch.float32, device=torch.device("cpu"))
+
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
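The new `parameters()` shim exists only so device/dtype probes such as `next(model.parameters())` keep working; a sketch of what it does and does not support (`rbln_model` is a placeholder for any loaded RBLN model):

```python
# Supported: device/dtype sniffing via the single dummy tensor.
probe = next(rbln_model.parameters())
print(probe.dtype, probe.device)  # torch.float32 cpu

# Not supported: real parameter enumeration -- this counts only the dummy.
n_params = sum(p.numel() for p in rbln_model.parameters())  # misleading: 1
```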
 
@@ -448,7 +479,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             save_directory (`Union[str, Path]`):
                 Directory where to save the model file.
             push_to_hub (`bool`, *optional*, defaults to `False`):
-                Whether or not to push your model to the Hugging Face model hub after saving it.
+                Whether or not to push your model to the HuggingFace model hub after saving it.
 
         """
         if os.path.isfile(save_directory):
@@ -481,11 +512,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             # First copy everything to a temporary directory
             shutil.copytree(real_save_dir, tmp_dir)
 
-            # Save configs to the temporary directory
-            self.config.save_pretrained(tmp_dir)
-            if self.generation_config is not None:
-                self.generation_config.save_pretrained(tmp_dir)
-
             # If everything succeeded, atomically replace the target directory
             if os.path.exists(save_directory_path):
                 shutil.rmtree(save_directory_path)
@@ -521,7 +547,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
     @classmethod
     @abstractmethod
-    def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
+    def _update_rbln_config(cls, **rbln_config_kwargs) -> RBLNModelConfig:
         pass
 
     @classmethod
@@ -529,8 +555,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
         # compiled_models -> runtimes
         pass
@@ -542,11 +567,11 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
     @classmethod
     @abstractmethod
-    @use_rbln_config
     def from_model(
         cls,
         model: "PreTrainedModel",
-        rbln_config: Dict[str, Any] = {},
+        config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         **kwargs,
optimum/rbln/ops/__init__.py CHANGED
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .attn import (
-    register_rbln_custom_add_softmax_attention,
-    register_rbln_custom_paged_attention,
-    register_rbln_custom_paged_causal_attention,
-)
-from .flash_attn import register_rbln_custom_paged_flash_attention, register_rbln_custom_paged_flash_causal_attention
-from .kv_cache_update import register_rbln_custom_cache_update
+from .attn import *
+from .flash_attn import *
+from .kv_cache_update import *
+from .linear import linear