optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. See the registry's advisory page for more details.

Files changed (120)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  23. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  24. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  25. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  31. optimum/rbln/modeling.py +4 -5
  32. optimum/rbln/modeling_base.py +18 -14
  33. optimum/rbln/ops/kv_cache_update.py +5 -0
  34. optimum/rbln/ops/linear.py +7 -0
  35. optimum/rbln/transformers/__init__.py +60 -0
  36. optimum/rbln/transformers/configuration_generic.py +4 -4
  37. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  38. optimum/rbln/transformers/modeling_generic.py +1 -4
  39. optimum/rbln/transformers/models/__init__.py +45 -30
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  41. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  42. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  43. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  44. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  45. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  46. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  47. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  48. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  51. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  52. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  53. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  54. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  55. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  56. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  57. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  58. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  59. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  60. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  61. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  62. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  63. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  64. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  65. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  66. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  67. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  68. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  69. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  75. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  76. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  77. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  78. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  79. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  80. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  81. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  82. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  83. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  84. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  85. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  86. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  87. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  91. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  92. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  93. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  94. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  97. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  101. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  102. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  103. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  104. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  105. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  106. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  107. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  108. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  110. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  111. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  112. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  113. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  114. optimum/rbln/utils/depreacate_utils.py +16 -0
  115. optimum/rbln/utils/hub.py +8 -47
  116. optimum/rbln/utils/runtime_utils.py +31 -5
  117. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  118. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
  119. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  120. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
@@ -70,8 +70,6 @@ class RBLNDiffusionMixin:
70
70
  _submodules = []
71
71
  _optional_submodules = []
72
72
  _prefix = {}
73
- _rbln_config_class = None
74
- _hf_class = None
75
73
 
76
74
  @staticmethod
77
75
  def _maybe_apply_and_fuse_lora(
@@ -114,14 +112,14 @@ class RBLNDiffusionMixin:
114
112
  @classmethod
115
113
  def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
116
114
  # Lazily loads and caches the corresponding RBLN model config class.
117
- if cls._rbln_config_class is None:
115
+ if "_rbln_config_class" not in cls.__dict__ or cls._rbln_config_class is None:
118
116
  rbln_config_class_name = cls.__name__ + "Config"
119
117
  cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
120
118
  return cls._rbln_config_class
121
119
 
122
120
  @classmethod
123
121
  def get_hf_class(cls):
124
- if cls._hf_class is None:
122
+ if "_hf_class" not in cls.__dict__ or cls._hf_class is None:
125
123
  hf_cls_name = cls.__name__[4:]
126
124
  library = importlib.import_module("diffusers")
127
125
  cls._hf_class = getattr(library, hf_cls_name, None)
@@ -138,7 +136,7 @@ class RBLNDiffusionMixin:
138
136
  lora_ids: Optional[Union[str, List[str]]] = None,
139
137
  lora_weights_names: Optional[Union[str, List[str]]] = None,
140
138
  lora_scales: Optional[Union[float, List[float]]] = None,
141
- **kwargs: Dict[str, Any],
139
+ **kwargs: Any,
142
140
  ) -> "RBLNDiffusionMixin":
143
141
  """
144
142
  Load a pretrained diffusion pipeline from a model checkpoint, with optional compilation for RBLN NPUs.
@@ -230,6 +228,7 @@ class RBLNDiffusionMixin:
230
228
  create_runtimes=rbln_config.create_runtimes,
231
229
  optimize_host_mem=rbln_config.optimize_host_memory,
232
230
  activate_profiler=rbln_config.activate_profiler,
231
+ timeout=rbln_config.timeout,
233
232
  ):
234
233
  model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)
235
234
 
@@ -35,20 +35,10 @@ _import_structure = {
35
35
  }
36
36
 
37
37
  if TYPE_CHECKING:
38
- from .autoencoders import (
39
- RBLNAutoencoderKL,
40
- RBLNAutoencoderKLCosmos,
41
- RBLNVQModel,
42
- )
38
+ from .autoencoders import RBLNAutoencoderKL, RBLNAutoencoderKLCosmos, RBLNVQModel
43
39
  from .controlnet import RBLNControlNetModel
44
- from .transformers import (
45
- RBLNCosmosTransformer3DModel,
46
- RBLNPriorTransformer,
47
- RBLNSD3Transformer2DModel,
48
- )
49
- from .unets import (
50
- RBLNUNet2DConditionModel,
51
- )
40
+ from .transformers import RBLNCosmosTransformer3DModel, RBLNPriorTransformer, RBLNSD3Transformer2DModel
41
+ from .unets import RBLNUNet2DConditionModel
52
42
  else:
53
43
  import sys
54
44
 
@@ -209,6 +209,7 @@ class RBLNAutoencoderKL(RBLNModel):
209
209
  tensor_type="pt",
210
210
  device=device_val,
211
211
  activate_profiler=rbln_config.activate_profiler,
212
+ timeout=rbln_config.timeout,
212
213
  )
213
214
  for compiled_model, device_val in zip(compiled_models, device_vals)
214
215
  ]
@@ -200,6 +200,7 @@ class RBLNAutoencoderKLCosmos(RBLNModel):
200
200
  tensor_type="pt",
201
201
  device=device_val,
202
202
  activate_profiler=rbln_config.activate_profiler,
203
+ timeout=rbln_config.timeout,
203
204
  )
204
205
  for compiled_model, device_val in zip(compiled_models, device_vals)
205
206
  ]
@@ -165,6 +165,7 @@ class RBLNVQModel(RBLNModel):
165
165
  tensor_type="pt",
166
166
  device=device_val,
167
167
  activate_profiler=rbln_config.activate_profiler,
168
+ timeout=rbln_config.timeout,
168
169
  )
169
170
  for compiled_model, device_val in zip(compiled_models, device_vals)
170
171
  ]
@@ -279,7 +279,7 @@ class RBLNCosmosTransformer3DModel(RBLNModel):
279
279
  tensor_type="pt",
280
280
  device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
281
281
  activate_profiler=rbln_config.activate_profiler,
282
- timeout=120,
282
+ timeout=rbln_config.timeout,
283
283
  )
284
284
  for compiled_model in compiled_models
285
285
  ]
@@ -63,11 +63,7 @@ if TYPE_CHECKING:
63
63
  RBLNStableDiffusionXLControlNetImg2ImgPipeline,
64
64
  RBLNStableDiffusionXLControlNetPipeline,
65
65
  )
66
- from .cosmos import (
67
- RBLNCosmosSafetyChecker,
68
- RBLNCosmosTextToWorldPipeline,
69
- RBLNCosmosVideoToWorldPipeline,
70
- )
66
+ from .cosmos import RBLNCosmosSafetyChecker, RBLNCosmosTextToWorldPipeline, RBLNCosmosVideoToWorldPipeline
71
67
  from .kandinsky2_2 import (
72
68
  RBLNKandinskyV22CombinedPipeline,
73
69
  RBLNKandinskyV22Img2ImgCombinedPipeline,
@@ -12,10 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional, Tuple
15
+ from typing import Any, Optional, Tuple
16
16
 
17
17
  from ....configuration_utils import RBLNAutoConfig, RBLNModelConfig
18
- from ....transformers import RBLNSiglipVisionModelConfig
18
+ from ....transformers import RBLNLlamaForCausalLMConfig, RBLNSiglipVisionModelConfig
19
19
 
20
20
 
21
21
  class RBLNVideoSafetyModelConfig(RBLNModelConfig):
@@ -69,13 +69,21 @@ class RBLNCosmosSafetyCheckerConfig(RBLNModelConfig):
69
69
  image_size: Optional[Tuple[int, int]] = None,
70
70
  height: Optional[int] = None,
71
71
  width: Optional[int] = None,
72
- **kwargs: Dict[str, Any],
72
+ **kwargs: Any,
73
73
  ):
74
74
  super().__init__(**kwargs)
75
75
  if height is not None and width is not None:
76
76
  image_size = (height, width)
77
77
 
78
- self.aegis = self.init_submodule_config(RBLNModelConfig, aegis)
78
+ tensor_parallel_size = kwargs.get("tensor_parallel_size")
79
+
80
+ self.aegis = self.init_submodule_config(
81
+ RBLNLlamaForCausalLMConfig,
82
+ aegis,
83
+ batch_size=batch_size,
84
+ tensor_parallel_size=tensor_parallel_size,
85
+ )
86
+
79
87
  self.siglip_encoder = self.init_submodule_config(
80
88
  RBLNSiglipVisionModelConfig,
81
89
  siglip_encoder,
@@ -127,23 +127,13 @@ class RBLNSigLIPEncoder(SigLIPEncoder):
127
127
 
128
128
  # We don't use RBLNSiglipModel, but we need to override get_image_features to return pooler_output
129
129
  self.model = RBLNSiglipVisionModel.from_pretrained(
130
- self.checkpoint_dir,
131
- rbln_device=rbln_config.siglip_encoder.device,
132
- rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
133
- rbln_activate_profiler=rbln_config.aegis.activate_profiler,
130
+ self.checkpoint_dir, rbln_config=rbln_config.siglip_encoder
134
131
  )
135
132
  else:
136
133
  super().__init__(model_name, checkpoint_id)
137
134
  model = self.model
138
135
  del self.model
139
- self.model = RBLNSiglipVisionModel.from_model(
140
- model,
141
- rbln_device=rbln_config.siglip_encoder.device,
142
- rbln_image_size=rbln_config.siglip_encoder.image_size,
143
- rbln_npu=rbln_config.siglip_encoder.npu,
144
- rbln_create_runtimes=rbln_config.siglip_encoder.create_runtimes,
145
- rbln_activate_profiler=rbln_config.siglip_encoder.activate_profiler,
146
- )
136
+ self.model = RBLNSiglipVisionModel.from_model(model, rbln_config=rbln_config.siglip_encoder)
147
137
  self.rbln_config = rbln_config
148
138
 
149
139
  # Override get_image_features to return pooler_output
@@ -334,26 +324,14 @@ class RBLNAegis(Aegis):
334
324
  torch.nn.Module.__init__(self)
335
325
  cache_dir = pathlib.Path(checkpoint_id) / "aegis"
336
326
  self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
337
- self.model = RBLNAutoModelForCausalLM.from_pretrained(
338
- cache_dir,
339
- rbln_device=rbln_config.aegis.device,
340
- rbln_create_runtimes=rbln_config.aegis.create_runtimes,
341
- rbln_activate_profiler=rbln_config.aegis.activate_profiler,
342
- )
327
+ self.model = RBLNAutoModelForCausalLM.from_pretrained(cache_dir, rbln_config=rbln_config.aegis)
343
328
 
344
329
  else:
345
330
  super().__init__(checkpoint_id, base_model_id, aegis_adapter)
346
331
  model = self.model.merge_and_unload() # peft merge
347
332
  del self.model
348
333
 
349
- self.model = RBLNAutoModelForCausalLM.from_model(
350
- model,
351
- rbln_tensor_parallel_size=4,
352
- rbln_device=rbln_config.aegis.device,
353
- rbln_create_runtimes=rbln_config.aegis.create_runtimes,
354
- rbln_npu=rbln_config.aegis.npu,
355
- rbln_activate_profiler=rbln_config.aegis.activate_profiler,
356
- )
334
+ self.model = RBLNAutoModelForCausalLM.from_model(model, rbln_config=rbln_config.aegis)
357
335
 
358
336
  self.rbln_config = rbln_config
359
337
  self.dtype = torch.bfloat16
@@ -35,7 +35,7 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
35
35
  RBLN-accelerated implementation of Cosmos Text to World pipeline for text-to-video generation.
36
36
 
37
37
  This pipeline compiles Cosmos Text to World models to run efficiently on RBLN NPUs, enabling high-performance
38
- inference for generating images with distinctive artistic style and enhanced visual quality.
38
+ inference for generating videos with distinctive artistic style and enhanced visual quality.
39
39
  """
40
40
 
41
41
  original_class = CosmosTextToWorldPipeline
@@ -87,7 +87,7 @@ class RBLNCosmosTextToWorldPipeline(RBLNDiffusionMixin, CosmosTextToWorldPipelin
87
87
  export: bool = False,
88
88
  safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
89
89
  rbln_config: Dict[str, Any] = {},
90
- **kwargs: Dict[str, Any],
90
+ **kwargs: Any,
91
91
  ):
92
92
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
93
93
  if safety_checker is None and export:
@@ -35,7 +35,7 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
35
35
  RBLN-accelerated implementation of Cosmos Video to World pipeline for video-to-video generation.
36
36
 
37
37
  This pipeline compiles Cosmos Video to World models to run efficiently on RBLN NPUs, enabling high-performance
38
- inference for generating images with distinctive artistic style and enhanced visual quality.
38
+ inference for generating videos with distinctive artistic style and enhanced visual quality.
39
39
  """
40
40
 
41
41
  original_class = CosmosVideoToWorldPipeline
@@ -87,7 +87,7 @@ class RBLNCosmosVideoToWorldPipeline(RBLNDiffusionMixin, CosmosVideoToWorldPipel
87
87
  export: bool = False,
88
88
  safety_checker: Optional[RBLNCosmosSafetyChecker] = None,
89
89
  rbln_config: Dict[str, Any] = {},
90
- **kwargs: Dict[str, Any],
90
+ **kwargs: Any,
91
91
  ):
92
92
  rbln_config, kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
93
93
  if safety_checker is None and export:
@@ -22,12 +22,7 @@ from diffusers import (
22
22
  UNet2DConditionModel,
23
23
  VQModel,
24
24
  )
25
- from transformers import (
26
- CLIPImageProcessor,
27
- CLIPTextModelWithProjection,
28
- CLIPTokenizer,
29
- CLIPVisionModelWithProjection,
30
- )
25
+ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
31
26
 
32
27
  from ...configurations import RBLNKandinskyV22CombinedPipelineConfig
33
28
  from ...modeling_diffusers import RBLNDiffusionMixin
optimum/rbln/modeling.py CHANGED
@@ -35,8 +35,6 @@ logger = get_logger(__name__)
35
35
 
36
36
 
37
37
  class RBLNModel(RBLNBaseModel):
38
- _output_class = None
39
-
40
38
  @classmethod
41
39
  def update_kwargs(cls, kwargs):
42
40
  # Update user-given kwargs to get proper pytorch model.
@@ -80,7 +78,7 @@ class RBLNModel(RBLNBaseModel):
80
78
  rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
81
79
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
82
80
  subfolder: str = "",
83
- **kwargs: Dict[str, Any],
81
+ **kwargs: Any,
84
82
  ) -> "RBLNModel":
85
83
  """
86
84
  Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
@@ -238,11 +236,12 @@ class RBLNModel(RBLNBaseModel):
238
236
  tensor_type="pt",
239
237
  device=rbln_config.device_map[DEFAULT_COMPILED_MODEL_NAME],
240
238
  activate_profiler=rbln_config.activate_profiler,
239
+ timeout=rbln_config.timeout,
241
240
  )
242
241
  for compiled_model in compiled_models
243
242
  ]
244
243
 
245
- def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Dict[str, Any]) -> Any:
244
+ def forward(self, *args: Any, return_dict: Optional[bool] = None, **kwargs: Any) -> Any:
246
245
  """
247
246
  Defines the forward pass of the RBLN model, providing a drop-in replacement for HuggingFace PreTrainedModel.
248
247
 
@@ -288,7 +287,7 @@ class RBLNModel(RBLNBaseModel):
288
287
  @classmethod
289
288
  def get_hf_output_class(cls):
290
289
  # Dynamically gets the output class from the corresponding HuggingFace model class.
291
- if cls._output_class:
290
+ if "_output_class" in cls.__dict__ and cls._output_class is not None:
292
291
  return cls._output_class
293
292
 
294
293
  hf_class = cls.get_hf_class()
@@ -23,9 +23,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
23
23
  import rebel
24
24
  import torch
25
25
  from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
26
+ from transformers.utils.hub import PushToHubMixin
26
27
 
27
28
  from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig, get_rbln_config_class
28
- from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
29
+ from .utils.hub import pull_compiled_model_from_hub, validate_files
29
30
  from .utils.logging import get_logger
30
31
  from .utils.runtime_utils import UnavailableRuntime, tp_and_devices_are_ok
31
32
  from .utils.save_utils import maybe_load_preprocessors
@@ -50,11 +51,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
50
51
  model_type = "rbln_model"
51
52
  auto_model_class = AutoModel
52
53
  config_class = AutoConfig
53
-
54
54
  config_name = "config.json"
55
55
  hf_library_name = "transformers"
56
- _hf_class = None
57
- _rbln_config_class = None
58
56
 
59
57
  def __init__(
60
58
  self,
@@ -115,7 +113,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
115
113
  def _load_compiled_model_dir(
116
114
  cls,
117
115
  model_id: Union[str, Path],
118
- use_auth_token: Optional[Union[bool, str]] = None,
116
+ token: Optional[Union[bool, str]] = None,
119
117
  revision: Optional[str] = None,
120
118
  force_download: bool = False,
121
119
  cache_dir: Optional[str] = None,
@@ -134,7 +132,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
134
132
  model_path = pull_compiled_model_from_hub(
135
133
  model_id=model_id,
136
134
  subfolder=subfolder,
137
- use_auth_token=use_auth_token,
135
+ token=token,
138
136
  revision=revision,
139
137
  cache_dir=cache_dir,
140
138
  force_download=force_download,
@@ -172,7 +170,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
172
170
  cls,
173
171
  model_id: Union[str, Path],
174
172
  config: Optional["PretrainedConfig"] = None,
175
- use_auth_token: Optional[Union[bool, str]] = None,
173
+ token: Optional[Union[bool, str]] = None,
176
174
  revision: Optional[str] = None,
177
175
  force_download: bool = False,
178
176
  cache_dir: Optional[str] = None,
@@ -189,7 +187,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
189
187
  if rbln_compiled_models is None:
190
188
  model_path_subfolder = cls._load_compiled_model_dir(
191
189
  model_id=model_id,
192
- use_auth_token=use_auth_token,
190
+ token=token,
193
191
  revision=revision,
194
192
  force_download=force_download,
195
193
  cache_dir=cache_dir,
@@ -232,7 +230,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
232
230
  cache_dir=cache_dir,
233
231
  force_download=force_download,
234
232
  revision=revision,
235
- token=use_auth_token,
233
+ token=token,
236
234
  trust_remote_code=trust_remote_code,
237
235
  )
238
236
  elif cls.hf_library_name == "diffusers":
@@ -250,7 +248,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
250
248
  force_download=force_download,
251
249
  local_files_only=local_files_only,
252
250
  revision=revision,
253
- token=use_auth_token,
251
+ token=token,
254
252
  subfolder=subfolder,
255
253
  )
256
254
  config = PretrainedConfig(**config)
@@ -350,7 +348,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
350
348
  model_id: Union[str, Path],
351
349
  export: bool = False,
352
350
  rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
353
- **kwargs: Dict[str, Any],
351
+ **kwargs: Any,
354
352
  ) -> "RBLNBaseModel":
355
353
  """
356
354
  The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
@@ -421,7 +419,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
421
419
 
422
420
  # Returns:
423
421
  # type: The original HuggingFace model class
424
- if cls._hf_class is None:
422
+ if "_hf_class" not in cls.__dict__ or cls._hf_class is None:
425
423
  hf_cls_name = cls.__name__[4:]
426
424
  library = importlib.import_module(cls.hf_library_name)
427
425
  cls._hf_class = getattr(library, hf_cls_name, None)
@@ -430,7 +428,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
430
428
  @classmethod
431
429
  def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
432
430
  # Lazily loads and caches the corresponding RBLN model config class.
433
- if cls._rbln_config_class is None:
431
+ if "_rbln_config_class" not in cls.__dict__ or cls._rbln_config_class is None:
434
432
  rbln_config_class_name = cls.__name__ + "Config"
435
433
  cls._rbln_config_class = get_rbln_config_class(rbln_config_class_name)
436
434
  return cls._rbln_config_class
@@ -507,6 +505,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
507
505
  f"Please ensure the model directory exists and you have the necessary permissions to access it."
508
506
  )
509
507
 
508
+ if isinstance(self.config, PretrainedConfig):
509
+ self.config.save_pretrained(real_save_dir)
510
+
510
511
  if save_directory_path == real_save_dir:
511
512
  raise FileExistsError(
512
513
  f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
@@ -534,7 +535,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
534
535
  raise e # Re-raise the exception after cleanup
535
536
 
536
537
  if push_to_hub:
537
- return super().push_to_hub(str(save_directory_path), **kwargs)
538
+ repo_id = kwargs.pop("repo_id", None)
539
+ if repo_id is None:
540
+ raise ValueError("`repo_id` must be provided to push the model to the HuggingFace model hub.")
541
+ return super().push_to_hub(repo_id=repo_id, **kwargs)
538
542
 
539
543
  @staticmethod
540
544
  def _raise_missing_compiled_file_error(missing_files: List[str]):
@@ -22,3 +22,8 @@ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tens
22
22
  # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
23
23
  # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
24
24
  return torch.empty_like(cache)
25
+
26
+
27
+ @rbln_cache_update.register_fake
28
+ def rbln_cache_update_fake(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
29
+ return torch.empty_like(cache)
@@ -23,3 +23,10 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens
23
23
  output_shape = list(input.shape[:-1])
24
24
  output_shape += [weight.shape[0]]
25
25
  return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
26
+
27
+
28
+ @linear.register_fake
29
+ def linear_fake(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
30
+ output_shape = list(input.shape[:-1])
31
+ output_shape += [weight.shape[0]]
32
+ return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
@@ -62,12 +62,16 @@ _import_structure = {
62
62
  "RBLNCLIPVisionModelWithProjectionConfig",
63
63
  "RBLNDecoderOnlyModelForCausalLM",
64
64
  "RBLNDecoderOnlyModelForCausalLMConfig",
65
+ "RBLNDecoderOnlyModelConfig",
66
+ "RBLNDecoderOnlyModel",
65
67
  "RBLNDistilBertForQuestionAnswering",
66
68
  "RBLNDistilBertForQuestionAnsweringConfig",
67
69
  "RBLNDPTForDepthEstimation",
68
70
  "RBLNDPTForDepthEstimationConfig",
69
71
  "RBLNExaoneForCausalLM",
70
72
  "RBLNExaoneForCausalLMConfig",
73
+ "RBLNGemmaModel",
74
+ "RBLNGemmaModelConfig",
71
75
  "RBLNGemma3ForCausalLM",
72
76
  "RBLNGemma3ForCausalLMConfig",
73
77
  "RBLNGemma3ForConditionalGeneration",
@@ -76,28 +80,54 @@ _import_structure = {
76
80
  "RBLNGemmaForCausalLMConfig",
77
81
  "RBLNGPT2LMHeadModel",
78
82
  "RBLNGPT2LMHeadModelConfig",
83
+ "RBLNGPT2Model",
84
+ "RBLNGPT2ModelConfig",
79
85
  "RBLNIdefics3ForConditionalGeneration",
80
86
  "RBLNIdefics3ForConditionalGenerationConfig",
81
87
  "RBLNIdefics3VisionTransformer",
82
88
  "RBLNIdefics3VisionTransformerConfig",
83
89
  "RBLNLlamaForCausalLM",
84
90
  "RBLNLlamaForCausalLMConfig",
91
+ "RBLNLlavaForConditionalGeneration",
92
+ "RBLNLlavaForConditionalGenerationConfig",
93
+ "RBLNLlamaModel",
94
+ "RBLNLlamaModelConfig",
95
+ "RBLNOPTForCausalLM",
96
+ "RBLNOPTForCausalLMConfig",
97
+ "RBLNPegasusForConditionalGeneration",
98
+ "RBLNPegasusForConditionalGenerationConfig",
99
+ "RBLNPegasusModel",
100
+ "RBLNPegasusModelConfig",
85
101
  "RBLNLlavaNextForConditionalGeneration",
86
102
  "RBLNLlavaNextForConditionalGenerationConfig",
87
103
  "RBLNMidmLMHeadModel",
88
104
  "RBLNMidmLMHeadModelConfig",
89
105
  "RBLNMistralForCausalLM",
90
106
  "RBLNMistralForCausalLMConfig",
107
+ "RBLNMistralModel",
108
+ "RBLNMistralModelConfig",
91
109
  "RBLNOPTForCausalLM",
92
110
  "RBLNOPTForCausalLMConfig",
111
+ "RBLNOPTModel",
112
+ "RBLNOPTModelConfig",
93
113
  "RBLNPhiForCausalLM",
94
114
  "RBLNPhiForCausalLMConfig",
115
+ "RBLNPixtralVisionModelConfig",
116
+ "RBLNPixtralVisionModel",
117
+ "RBLNPhiModel",
118
+ "RBLNPhiModelConfig",
95
119
  "RBLNQwen2_5_VisionTransformerPretrainedModel",
96
120
  "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
97
121
  "RBLNQwen2_5_VLForConditionalGeneration",
98
122
  "RBLNQwen2_5_VLForConditionalGenerationConfig",
123
+ "RBLNQwen2Model",
124
+ "RBLNQwen2ModelConfig",
99
125
  "RBLNQwen2ForCausalLM",
100
126
  "RBLNQwen2ForCausalLMConfig",
127
+ "RBLNQwen3ForCausalLM",
128
+ "RBLNQwen3ForCausalLMConfig",
129
+ "RBLNQwen3Model",
130
+ "RBLNQwen3ModelConfig",
101
131
  "RBLNResNetForImageClassification",
102
132
  "RBLNResNetForImageClassificationConfig",
103
133
  "RBLNRobertaForMaskedLM",
@@ -166,6 +196,10 @@ if TYPE_CHECKING:
166
196
  RBLNCLIPVisionModelConfig,
167
197
  RBLNCLIPVisionModelWithProjection,
168
198
  RBLNCLIPVisionModelWithProjectionConfig,
199
+ RBLNColPaliForRetrieval,
200
+ RBLNColPaliForRetrievalConfig,
201
+ RBLNDecoderOnlyModel,
202
+ RBLNDecoderOnlyModelConfig,
169
203
  RBLNDecoderOnlyModelForCausalLM,
170
204
  RBLNDecoderOnlyModelForCausalLMConfig,
171
205
  RBLNDistilBertForQuestionAnswering,
@@ -180,30 +214,56 @@ if TYPE_CHECKING:
180
214
  RBLNGemma3ForConditionalGenerationConfig,
181
215
  RBLNGemmaForCausalLM,
182
216
  RBLNGemmaForCausalLMConfig,
217
+ RBLNGemmaModel,
218
+ RBLNGemmaModelConfig,
183
219
  RBLNGPT2LMHeadModel,
184
220
  RBLNGPT2LMHeadModelConfig,
221
+ RBLNGPT2Model,
222
+ RBLNGPT2ModelConfig,
185
223
  RBLNIdefics3ForConditionalGeneration,
186
224
  RBLNIdefics3ForConditionalGenerationConfig,
187
225
  RBLNIdefics3VisionTransformer,
188
226
  RBLNIdefics3VisionTransformerConfig,
189
227
  RBLNLlamaForCausalLM,
190
228
  RBLNLlamaForCausalLMConfig,
229
+ RBLNLlamaModel,
230
+ RBLNLlamaModelConfig,
231
+ RBLNLlavaForConditionalGeneration,
232
+ RBLNLlavaForConditionalGenerationConfig,
191
233
  RBLNLlavaNextForConditionalGeneration,
192
234
  RBLNLlavaNextForConditionalGenerationConfig,
193
235
  RBLNMidmLMHeadModel,
194
236
  RBLNMidmLMHeadModelConfig,
195
237
  RBLNMistralForCausalLM,
196
238
  RBLNMistralForCausalLMConfig,
239
+ RBLNMistralModel,
240
+ RBLNMistralModelConfig,
197
241
  RBLNOPTForCausalLM,
198
242
  RBLNOPTForCausalLMConfig,
243
+ RBLNOPTModel,
244
+ RBLNOPTModelConfig,
245
+ RBLNPegasusForConditionalGeneration,
246
+ RBLNPegasusForConditionalGenerationConfig,
247
+ RBLNPegasusModel,
248
+ RBLNPegasusModelConfig,
199
249
  RBLNPhiForCausalLM,
200
250
  RBLNPhiForCausalLMConfig,
251
+ RBLNPhiModel,
252
+ RBLNPhiModelConfig,
253
+ RBLNPixtralVisionModel,
254
+ RBLNPixtralVisionModelConfig,
201
255
  RBLNQwen2_5_VisionTransformerPretrainedModel,
202
256
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
203
257
  RBLNQwen2_5_VLForConditionalGeneration,
204
258
  RBLNQwen2_5_VLForConditionalGenerationConfig,
205
259
  RBLNQwen2ForCausalLM,
206
260
  RBLNQwen2ForCausalLMConfig,
261
+ RBLNQwen2Model,
262
+ RBLNQwen2ModelConfig,
263
+ RBLNQwen3ForCausalLM,
264
+ RBLNQwen3ForCausalLMConfig,
265
+ RBLNQwen3Model,
266
+ RBLNQwen3ModelConfig,
207
267
  RBLNResNetForImageClassification,
208
268
  RBLNResNetForImageClassificationConfig,
209
269
  RBLNRobertaForMaskedLM,
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, List, Optional, Tuple, Union
15
+ from typing import Any, List, Optional, Tuple, Union
16
16
 
17
17
  from ..configuration_utils import RBLNModelConfig
18
18
 
@@ -25,7 +25,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
25
25
  max_seq_len: Optional[int] = None,
26
26
  batch_size: Optional[int] = None,
27
27
  model_input_names: Optional[List[str]] = None,
28
- **kwargs: Dict[str, Any],
28
+ **kwargs: Any,
29
29
  ):
30
30
  """
31
31
  Args:
@@ -52,7 +52,7 @@ class RBLNImageModelConfig(RBLNModelConfig):
52
52
  self,
53
53
  image_size: Optional[Union[int, Tuple[int, int]]] = None,
54
54
  batch_size: Optional[int] = None,
55
- **kwargs: Dict[str, Any],
55
+ **kwargs: Any,
56
56
  ):
57
57
  """
58
58
  Args:
@@ -124,7 +124,7 @@ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
124
124
  batch_size: Optional[int] = None,
125
125
  max_length: Optional[int] = None,
126
126
  num_mel_bins: Optional[int] = None,
127
- **kwargs: Dict[str, Any],
127
+ **kwargs: Any,
128
128
  ):
129
129
  """
130
130
  Args: