optimum-rbln 0.8.1rc1__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Files changed (119)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +5 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  18. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  19. optimum/rbln/diffusers/models/__init__.py +3 -13
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  24. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  25. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  26. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -28
  27. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  30. optimum/rbln/modeling.py +4 -5
  31. optimum/rbln/modeling_base.py +18 -14
  32. optimum/rbln/ops/kv_cache_update.py +5 -0
  33. optimum/rbln/ops/linear.py +7 -0
  34. optimum/rbln/transformers/__init__.py +60 -0
  35. optimum/rbln/transformers/configuration_generic.py +4 -4
  36. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  37. optimum/rbln/transformers/modeling_generic.py +1 -4
  38. optimum/rbln/transformers/models/__init__.py +45 -30
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  44. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  45. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  46. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  47. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  48. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  49. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  51. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  52. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  53. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  54. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  55. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  56. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  57. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  58. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  59. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  60. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  61. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  63. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  64. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  65. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  66. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  67. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  68. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  69. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  70. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  71. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  72. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  73. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  74. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  75. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  76. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  77. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  78. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  79. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  80. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  81. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  82. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  83. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  84. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  85. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  86. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  90. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  91. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  92. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  93. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  94. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  95. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  96. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  97. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  98. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  99. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  100. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  101. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  102. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  103. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  104. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  105. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  106. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  108. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  110. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  111. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  112. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  113. optimum/rbln/utils/depreacate_utils.py +16 -0
  114. optimum/rbln/utils/hub.py +8 -47
  115. optimum/rbln/utils/runtime_utils.py +31 -5
  116. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  117. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +119 -102
  118. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  119. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/diffusers/models/transformers/transformer_cosmos.py CHANGED
@@ -279,7 +279,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
                 tensor_type="pt",
                 device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
                 activate_profiler=rbln_config.activate_profiler,
-                timeout=120,
+                timeout=rbln_config.timeout,
             )
             for compiled_model in compiled_models
         ]
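The only change here is that the runtime-creation timeout, previously hard-coded to 120 seconds, is now read from the model's RBLN config (modeling.py gains the same `timeout=rbln_config.timeout` pass-through further below). A minimal sketch of overriding it at load time, assuming the new `timeout` field is accepted through the `rbln_config` dict like other config fields; the path is hypothetical:

    from optimum.rbln import RBLNCosmosTransformer3DModel

    # Allow up to 10 minutes for rebel.Runtime creation on a busy NPU.
    model = RBLNCosmosTransformer3DModel.from_pretrained(
        "path/to/compiled/cosmos-transformer",  # hypothetical local path
        rbln_config={"timeout": 600},
    )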
optimum/rbln/diffusers/pipelines/__init__.py CHANGED
@@ -63,11 +63,7 @@ if TYPE_CHECKING:
         RBLNStableDiffusionXLControlNetImg2ImgPipeline,
         RBLNStableDiffusionXLControlNetPipeline,
     )
-    from .cosmos import (
-        RBLNCosmosSafetyChecker,
-        RBLNCosmosTextToWorldPipeline,
-        RBLNCosmosVideoToWorldPipeline,
-    )
+    from .cosmos import RBLNCosmosSafetyChecker, RBLNCosmosTextToWorldPipeline, RBLNCosmosVideoToWorldPipeline
     from .kandinsky2_2 import (
         RBLNKandinskyV22CombinedPipeline,
         RBLNKandinskyV22Img2ImgCombinedPipeline,
optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py CHANGED
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional, Tuple

 from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
-from ....transformers import RBLNSiglipVisionModelConfig
+from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig


 class RBLNVideoSafetyModelConfig(RBLNModelConfig):
@@ -69,13 +69,21 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
         image_size: Optional[Tuple[int, int]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         super().__init__(**kwargs)
         if height is not None and width is not None:
             image_size = (height, width)

-        self.aegis = self.init_submodule_config(RBLNModelConfig, aegis)
+        tensor_parallel_size = kwargs.get("tensor_parallel_size")
+
+        self.aegis = self.init_submodule_config(
+            RBLNLlamaForCausalLMConfig,
+            aegis,
+            batch_size=batch_size,
+            tensor_parallel_size=tensor_parallel_size,
+        )
+
         self.siglip_encoder = self.init_submodule_config(
             RBLNSiglipVisionModelConfig,
             siglip_encoder,
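Beyond wiring the Aegis submodule to a concrete RBLNLlamaForCausalLMConfig (and forwarding `batch_size` and `tensor_parallel_size` to it), this hunk is the first of many `**kwargs: Dict[str, Any]` to `**kwargs: Any` fixes in this release. Under PEP 484, the annotation on `**kwargs` types each individual keyword value, not the collected dict, so the old spelling claimed every value was itself a dict. A short illustration:

    from typing import Any, Dict

    def old_style(**kwargs: Dict[str, Any]) -> None:
        # Strictly read, every value must be a Dict[str, Any]:
        # old_style(batch_size=1) fails a strict type check.
        ...

    def new_style(**kwargs: Any) -> None:
        # Each value may be anything; kwargs itself is a Dict[str, Any].
        ...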
optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py CHANGED
@@ -127,24 +127,13 @@ class RBLNSigLIPEncoder(SigLIPEncoder):

             # We don't use RBLNSiglipModel, but we need to override get_image_features to return pooler_output
             self.model = RBLNSiglipVisionModel.from_pretrained(
-                self.checkpoint_dir,
-                rbln_device=rbln_config.siglip_encoder.device,
-                rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
-                rbln_activate_profiler=rbln_config.siglip_encoder.activate_profiler,
+                self.checkpoint_dir, rbln_config=rbln_config.siglip_encoder
             )
         else:
             super().__init__(model_name, checkpoint_id)
             model = self.model
             del self.model
-            self.model = RBLNSiglipVisionModel.from_model(
-                model,
-                rbln_device=rbln_config.siglip_encoder.device,
-                rbln_image_size=rbln_config.siglip_encoder.image_size,
-                rbln_npu=rbln_config.siglip_encoder.npu,
-                rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
-                rbln_activate_profiler=rbln_config.siglip_encoder.activate_profiler,
-                rbln_optimize_host_memory=rbln_config.siglip_encoder.optimize_host_memory,
-            )
+            self.model = RBLNSiglipVisionModel.from_model(model, rbln_config=rbln_config.siglip_encoder)
         self.rbln_config = rbln_config

         # Override get_image_features to return pooler_output
@@ -335,27 +324,14 @@ class RBLNAegis(Aegis):
             torch.nn.Module.__init__(self)
             cache_dir = pathlib.Path(checkpoint_id) / "aegis"
             self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
-            self.model = RBLNAutoModelForCausalLM.from_pretrained(
-                cache_dir,
-                rbln_device=rbln_config.aegis.device,
-                rbln_create_runtimes=rbln_config.aegis.create_runtimes,
-                rbln_activate_profiler=rbln_config.aegis.activate_profiler,
-            )
+            self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.aegis)

         else:
             super().__init__(checkpoint_id, base_model_id, aegis_adapter)
             model = self.model.merge_and_unload()  # peft merge
             del self.model

-            self.model = RBLNAutoModelForCausalLM.from_model(
-                model,
-                rbln_tensor_parallel_size=4,
-                rbln_device=rbln_config.aegis.device,
-                rbln_create_runtimes=rbln_config.aegis.create_runtimes,
-                rbln_npu=rbln_config.aegis.npu,
-                rbln_activate_profiler=rbln_config.aegis.activate_profiler,
-                rbln_optimize_host_memory=rbln_config.aegis.optimize_host_memory,
-            )
+            self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)

         self.rbln_config = rbln_config
         self.dtype = torch.bfloat16
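Both guardrail submodules now receive their settings as one nested config object instead of a flat list of `rbln_*` keyword arguments, and the previously hard-coded `rbln_tensor_parallel_size=4` for the Aegis LLM now flows in from the parent RBLNCosmosSafetyCheckerConfig (see the configuration hunk above). A hedged sketch of the resulting pattern from the caller's side, with hypothetical paths and values:

    from optimum.rbln import RBLNSiglipVisionModel, RBLNSiglipVisionModelConfig

    # One object carries device, image size, profiling flags, etc.
    cfg = RBLNSiglipVisionModelConfig(image_size=(384, 384), device=0)
    model = RBLNSiglipVisionModel.from_pretrained(
        "path/to/compiled/siglip",  # hypothetical
        rbln_config=cfg,
    )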
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py CHANGED
@@ -87,7 +87,7 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipeline):
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
         rbln_config: Dict[str, Any] = {},
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py CHANGED
@@ -87,7 +87,7 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipeline):
         export: bool = False,
         safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
         rbln_config: Dict[str, Any] = {},
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
         if safety_checker is None and export:
optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py CHANGED
@@ -22,12 +22,7 @@ from diffusers import (
     UNet2DConditionModel,
     VQModel,
 )
-from transformers import (
-    CLIPImageProcessor,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)
+from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
 from ...modeling_diffusers import RBLNDiffusionMixin
optimum/rbln/modeling.py CHANGED
@@ -35,8 +35,6 @@ logger = get_logger(__name__)


 class RBLNModel(RBLNBaseModel):
-    _output_class = None
-
     @classmethod
     def update_kwargs(cls, kwargs):
         # Update user-given kwargs to get proper pytorch model.
@@ -80,7 +78,7 @@ class RBLNModel(RBLNBaseModel):
         rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> "RBLNModel":
         """
         Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
@@ -238,11 +236,12 @@ class RBLNModel(RBLNBaseModel):
                 tensor_type="pt",
                 device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
                 activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
             )
             for compiled_model in compiled_models
         ]

-    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
+    def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
         """
         Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.

@@ -288,7 +287,7 @@ class RBLNModel(RBLNBaseModel):
     @classmethod
     def get_hf_output_class(cls):
         # Dynamically gets the output class from the corresponding HuggingFace model class.
-        if cls._output_class:
+        if "_output_class" in cls.__dict__ and cls._output_class is not None:
             return cls._output_class

         hf_class = cls.get_hf_class()
optimum/rbln/modeling_base.py CHANGED
@@ -23,9 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

 import rebel
 import torch
 from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
+from transformers.utils.hub import PushToHubMixin

 from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
-from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
+from .utils.hub import pull_compiled_model_from_hub, validate_files
 from .utils.logging import get_logger
 from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
 from .utils.save_utils import maybe_load_preprocessors
@@ -50,11 +51,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     model_type = "rbln_model"
     auto_model_class = AutoModel
     config_class = AutoConfig
-
     config_name = "config.json"
     hf_library_name = "transformers"
-    _hf_class = None
-    _rbln_config_class = None

     def __init__(
         self,
@@ -115,7 +113,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     def _load_compiled_model_dir(
         cls,
         model_id: Union[str, Path],
-        use_auth_token: Optional[Union[bool, str]] = None,
+        token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
@@ -134,7 +132,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             model_path = pull_compiled_model_from_hub(
                 model_id=model_id,
                 subfolder=subfolder,
-                use_auth_token=use_auth_token,
+                token=token,
                 revision=revision,
                 cache_dir=cache_dir,
                 force_download=force_download,
@@ -172,7 +170,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         cls,
         model_id: Union[str, Path],
         config: Optional["PretrainedConfig"] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
+        token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
@@ -189,7 +187,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         if rbln_compiled_models is None:
             model_path_subfolder = cls._load_compiled_model_dir(
                 model_id=model_id,
-                use_auth_token=use_auth_token,
+                token=token,
                 revision=revision,
                 force_download=force_download,
                 cache_dir=cache_dir,
@@ -232,7 +230,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 cache_dir=cache_dir,
                 force_download=force_download,
                 revision=revision,
-                token=use_auth_token,
+                token=token,
                 trust_remote_code=trust_remote_code,
             )
         elif cls.hf_library_name == "diffusers":
@@ -250,7 +248,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 force_download=force_download,
                 local_files_only=local_files_only,
                 revision=revision,
-                token=use_auth_token,
+                token=token,
                 subfolder=subfolder,
             )
             config = PretrainedConfig(**config)
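All hub-related plumbing switches from `use_auth_token` to `token`, the name huggingface_hub and transformers standardized on after deprecating `use_auth_token`. Assuming `from_pretrained` forwards it the way the renamed internals suggest, a caller would pass (repo id is hypothetical):

    model = RBLNLlamaForCausalLM.from_pretrained(
        "my-org/private-compiled-llama",  # hypothetical private repo
        token="hf_xxx",                   # or True to use the cached login
    )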
@@ -350,7 +348,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         model_id: Union[str, Path],
         export: bool = False,
         rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ) -> "RBLNBaseModel":
         """
         The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
@@ -421,7 +419,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):

         # Returns:
         #     type: The original HuggingFace model class
-        if cls._hf_class is None:
+        if "_hf_class" not in cls.__dict__ or cls._hf_class is None:
             hf_cls_name = cls.__name__[4:]
             library = importlib.import_module(cls.hf_library_name)
             cls._hf_class = getattr(library, hf_cls_name, None)
@@ -430,7 +428,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
     @classmethod
     def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
         # Lazily loads and caches the corresponding RBLN model config class.
-        if cls._rbln_config_class is None:
+        if "_rbln_config_class" not in cls.__dict__ or cls._rbln_config_class is None:
             rbln_config_class_name = cls.__name__ + "Config"
             cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
         return cls._rbln_config_class
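The `_hf_class` / `_rbln_config_class` caches (and `_output_class` in modeling.py) are no longer pre-declared on the base class; the getters instead check the class's own `__dict__`. Plain attribute access falls through to parent classes, so with the old pattern a subclass could inherit a value cached on its parent and return it instead of computing its own; the `cls.__dict__` check makes the cache strictly per-class. A toy illustration of the lookup difference:

    class Base:
        _cached = "base-value"

    class Child(Base):
        pass

    print(Child._cached)                # "base-value" (inherited lookup)
    print("_cached" in Child.__dict__)  # False: Child has no entry of its own

    Child._cached = "child-value"       # Child now caches independently
    print("_cached" in Child.__dict__)  # True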
@@ -507,6 +505,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 f"Please ensure the model directory exists and you have the necessary permissions to access it."
             )

+        if isinstance(self.config, PretrainedConfig):
+            self.config.save_pretrained(real_save_dir)
+
         if save_directory_path == real_save_dir:
             raise FileExistsError(
                 f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
@@ -534,7 +535,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 raise e  # Re-raise the exception after cleanup

         if push_to_hub:
-            return super().push_to_hub(str(save_directory_path), **kwargs)
+            repo_id = kwargs.pop("repo_id", None)
+            if repo_id is None:
+                raise ValueError("`repo_id` must be provided to push the model to the HuggingFace model hub.")
+            return super().push_to_hub(repo_id=repo_id, **kwargs)

     @staticmethod
     def _raise_missing_compiled_file_error(missing_files: List[str]):
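Pushing previously reused the local save path as the repository name; an explicit `repo_id` is now required. A minimal sketch with hypothetical names, assuming `save_pretrained` still forwards the extra kwargs to the push:

    model.save_pretrained(
        "compiled-llama-rbln",            # local output directory
        push_to_hub=True,
        repo_id="my-org/compiled-llama",  # now mandatory when pushing
    )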
optimum/rbln/ops/kv_cache_update.py CHANGED
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
     # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
     # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
     return torch.empty_like(cache)
+
+
+@rbln_cache_update.register_fake
+def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
+    return torch.empty_like(cache)
optimum/rbln/ops/linear.py CHANGED
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
     output_shape = list(input.shape[:-1])
     output_shape += [weight.shape[0]]
     return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
+
+
+@linear.register_fake
+def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
+    output_shape = list(input.shape[:-1])
+    output_shape += [weight.shape[0]]
+    return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
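Both custom ops gain an explicit fake kernel. PyTorch's tracing stack (`torch.compile` / `torch.export`) never executes a custom op's real body during graph capture; it calls the registered fake on meta tensors to learn output shapes and dtypes. The new registrations mirror each op's shape logic exactly. A self-contained sketch of the same pattern on a toy op (names are illustrative, not from this package):

    import torch
    from torch import Tensor

    # Real kernel (here a plain PyTorch stand-in for a device-side op).
    @torch.library.custom_op("mylib::row_sum", mutates_args=())
    def row_sum(x: Tensor) -> Tensor:
        return x.sum(dim=-1)

    # Fake kernel: return a correctly shaped empty tensor without touching
    # any data, so shape propagation works while tracing.
    @row_sum.register_fake
    def _(x: Tensor) -> Tensor:
        return torch.empty(x.shape[:-1], dtype=x.dtype, device=x.device)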
optimum/rbln/transformers/__init__.py CHANGED
@@ -62,12 +62,16 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjectionConfig",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNDecoderOnlyModelConfig",
+        "RBLNDecoderOnlyModel",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
         "RBLNDPTForDepthEstimation",
         "RBLNDPTForDepthEstimationConfig",
         "RBLNExaoneForCausalLM",
         "RBLNExaoneForCausalLMConfig",
+        "RBLNGemmaModel",
+        "RBLNGemmaModelConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
@@ -76,28 +80,54 @@ _import_structure = {
         "RBLNGemmaForCausalLMConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
+        "RBLNGPT2Model",
+        "RBLNGPT2ModelConfig",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3VisionTransformerConfig",
         "RBLNLlamaForCausalLM",
         "RBLNLlamaForCausalLMConfig",
+        "RBLNLlavaForConditionalGeneration",
+        "RBLNLlavaForConditionalGenerationConfig",
+        "RBLNLlamaModel",
+        "RBLNLlamaModelConfig",
+        "RBLNOPTForCausalLM",
+        "RBLNOPTForCausalLMConfig",
+        "RBLNPegasusForConditionalGeneration",
+        "RBLNPegasusForConditionalGenerationConfig",
+        "RBLNPegasusModel",
+        "RBLNPegasusModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
         "RBLNMistralForCausalLMConfig",
+        "RBLNMistralModel",
+        "RBLNMistralModelConfig",
         "RBLNOPTForCausalLM",
         "RBLNOPTForCausalLMConfig",
+        "RBLNOPTModel",
+        "RBLNOPTModelConfig",
         "RBLNPhiForCausalLM",
         "RBLNPhiForCausalLMConfig",
+        "RBLNPixtralVisionModelConfig",
+        "RBLNPixtralVisionModel",
+        "RBLNPhiModel",
+        "RBLNPhiModelConfig",
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2Model",
+        "RBLNQwen2ModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
+        "RBLNQwen3ForCausalLM",
+        "RBLNQwen3ForCausalLMConfig",
+        "RBLNQwen3Model",
+        "RBLNQwen3ModelConfig",
         "RBLNResNetForImageClassification",
         "RBLNResNetForImageClassificationConfig",
         "RBLNRobertaForMaskedLM",
@@ -166,6 +196,10 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelConfig,
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
+        RBLNColPaliForRetrieval,
+        RBLNColPaliForRetrievalConfig,
+        RBLNDecoderOnlyModel,
+        RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
         RBLNDistilBertForQuestionAnswering,
@@ -180,30 +214,56 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
         RBLNGemmaForCausalLM,
         RBLNGemmaForCausalLMConfig,
+        RBLNGemmaModel,
+        RBLNGemmaModelConfig,
         RBLNGPT2LMHeadModel,
         RBLNGPT2LMHeadModelConfig,
+        RBLNGPT2Model,
+        RBLNGPT2ModelConfig,
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
         RBLNIdefics3VisionTransformerConfig,
         RBLNLlamaForCausalLM,
         RBLNLlamaForCausalLMConfig,
+        RBLNLlamaModel,
+        RBLNLlamaModelConfig,
+        RBLNLlavaForConditionalGeneration,
+        RBLNLlavaForConditionalGenerationConfig,
         RBLNLlavaNextForConditionalGeneration,
         RBLNLlavaNextForConditionalGenerationConfig,
         RBLNMidmLMHeadModel,
         RBLNMidmLMHeadModelConfig,
         RBLNMistralForCausalLM,
         RBLNMistralForCausalLMConfig,
+        RBLNMistralModel,
+        RBLNMistralModelConfig,
         RBLNOPTForCausalLM,
         RBLNOPTForCausalLMConfig,
+        RBLNOPTModel,
+        RBLNOPTModelConfig,
+        RBLNPegasusForConditionalGeneration,
+        RBLNPegasusForConditionalGenerationConfig,
+        RBLNPegasusModel,
+        RBLNPegasusModelConfig,
         RBLNPhiForCausalLM,
         RBLNPhiForCausalLMConfig,
+        RBLNPhiModel,
+        RBLNPhiModelConfig,
+        RBLNPixtralVisionModel,
+        RBLNPixtralVisionModelConfig,
         RBLNQwen2_5_VisionTransformerPretrainedModel,
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
+        RBLNQwen2Model,
+        RBLNQwen2ModelConfig,
+        RBLNQwen3ForCausalLM,
+        RBLNQwen3ForCausalLMConfig,
+        RBLNQwen3Model,
+        RBLNQwen3ModelConfig,
         RBLNResNetForImageClassification,
         RBLNResNetForImageClassificationConfig,
         RBLNRobertaForMaskedLM,
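Alongside the existing `*ForCausalLM` wrappers, this release exports bare backbone classes (RBLNLlamaModel, RBLNQwen2Model, RBLNGPT2Model, and so on) plus several new architectures: LLaVA, Pixtral, Pegasus, and Qwen3. A hedged sketch of compiling one of the new classes through the usual export flow (checkpoint id and options are illustrative):

    from optimum.rbln import RBLNQwen3ForCausalLM

    model = RBLNQwen3ForCausalLM.from_pretrained(
        "Qwen/Qwen3-0.6B",            # illustrative checkpoint
        export=True,                  # compile from the HuggingFace weights
        rbln_batch_size=1,
        rbln_tensor_parallel_size=1,
    )
    model.save_pretrained("qwen3-0.6b-rbln")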
optimum/rbln/transformers/configuration_generic.py CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 from ..configuration_utils import RBLNModelConfig

@@ -25,7 +25,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
@@ -52,7 +52,7 @@ class RBLNImageModelConfig(RBLNModelConfig):
         self,
         image_size: Optional[Union[int, Tuple[int, int]]] = None,
         batch_size: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
@@ -124,7 +124,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         max_length: Optional[int] = None,
         num_mel_bins: Optional[int] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args: