optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling.py CHANGED
@@ -34,49 +34,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _get_dtype(
-    cls,
-    dtype: Optional[Union[str, torch.dtype, dict]],
-    config: PretrainedConfig,
-) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
-    dtype_orig = None
-
-    if dtype is not None:
-        if isinstance(dtype, str):
-            if dtype == "auto":
-                if hasattr(config, "dtype") and config.dtype is not None:
-                    dtype = config.dtype
-                else:
-                    dtype = torch.get_default_dtype()
-            elif hasattr(torch, dtype):
-                dtype = getattr(torch, dtype)
-            config.dtype = dtype
-        elif isinstance(dtype, torch.dtype):
-            config.dtype = dtype
-        elif isinstance(dtype, dict):
-            for key, curr_dtype in dtype.items():
-                if hasattr(config, key):
-                    value = getattr(config, key)
-                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
-                    value.dtype = curr_dtype
-            # main torch dtype for modules that aren't part of any sub-config
-            dtype = dtype.get("")
-            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
-            config.dtype = dtype
-            if dtype is None:
-                dtype = torch.float32
-        else:
-            raise ValueError(f"Invalid dtype: {dtype}")
-
-        dtype_orig = cls._set_default_dtype(dtype)
-    else:
-        # Use default dtype
-        default_dtype = torch.get_default_dtype()
-        config.dtype = default_dtype
-
-    return config, dtype, dtype_orig
-
-
 class RBLNModel(RBLNBaseModel):
     @classmethod
     def update_kwargs(cls, kwargs):
@@ -97,13 +54,16 @@ class RBLNModel(RBLNBaseModel):
         pass
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        # Wrap the model if needed.
        return model
 
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        model = cls.wrap_model_if_needed(model, rbln_config)
+        if rbln_config._allow_no_compile_cfgs:
+            return {}
+
+        model = cls._wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(
             model,
@@ -113,6 +73,18 @@
         )
         return compiled_model
 
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Any],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        # Default implementation: return config as-is
+        # Subclasses should override to set compile_cfgs if needed
+        return rbln_config
+
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         return model
@@ -277,6 +249,9 @@
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
+        if len(rbln_config.compile_cfgs) == 0:
+            return []
+
         if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
 
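Note: with the refactor above, get_compiled_model can skip compilation entirely when _allow_no_compile_cfgs is set, and _update_rbln_config becomes the documented override point where a subclass attaches its compile configs. A minimal subclass sketch follows; RBLNCompileConfig lives in optimum/rbln/configuration_utils.py (also touched in this release), but the set_compile_cfgs call and the input_info layout shown here are illustrative assumptions, not taken from this diff.

    # Hypothetical subclass; names marked below are assumptions for illustration.
    from optimum.rbln.configuration_utils import RBLNCompileConfig  # assumed import

    class RBLNMyEncoder(RBLNModel):
        @classmethod
        def _update_rbln_config(cls, preprocessors, model=None, model_config=None, rbln_config=None):
            # Attach one static-shape compile config; if compile_cfgs stays empty,
            # _create_runtimes (above) now returns no runtimes instead of failing.
            input_info = [("input_ids", [1, model_config.max_position_embeddings], "int64")]
            rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])  # assumed setter
            return rbln_config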
optimum/rbln/modeling_base.py CHANGED
@@ -15,7 +15,6 @@
 import importlib
 import os
 import shutil
-from abc import ABC
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -39,7 +38,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class PreTrainedModel(ABC):  # noqa: F811
+class PreTrainedModel:  # noqa: F811
     pass
 
 
@@ -63,7 +62,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
-        rbln_submodules: List["RBLNBaseModel"] = [],
+        rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
         **kwargs,
     ):
         self.model = models
@@ -71,7 +70,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         self.rbln_config = rbln_config
         if not rbln_config.is_frozen():
             raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
-
         self.compiled_models = rbln_compiled_models
 
         # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
@@ -92,7 +90,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
         self.device = torch.device("cpu")
         self.training = False
-        self.dtype = rbln_config.torch_dtype
+        self.dtype = rbln_config.dtype
 
         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -107,6 +105,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         self.model_save_dir = model_save_dir
         self.subfolder = subfolder
 
+        if rbln_submodules is None:
+            rbln_submodules = []
         self.rbln_submodules = rbln_submodules
         self.__post_init__(**kwargs)
 
@@ -182,7 +182,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         # passed from compile function
         rbln_config: Optional[RBLNModelConfig] = None,
         rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
-        rbln_submodules: List["RBLNBaseModel"] = [],
+        rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
         **kwargs,
     ) -> "RBLNBaseModel":
         if rbln_compiled_models is None:
@@ -218,12 +218,11 @@
         )
 
         if len(cls._rbln_submodules) > 0:
-            rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
-        else:
+            if rbln_submodules is None:
+                rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
+        elif rbln_submodules is None:
             rbln_submodules = []
 
-        rbln_config.freeze()
-
         if config is None:
             if cls.hf_library_name == "transformers":
                 config = AutoConfig.from_pretrained(
@@ -280,9 +279,12 @@
         config: "PretrainedConfig",
         model_save_dir: Union[Path, str],
         subfolder: Union[Path, str],
-        rbln_submodules: List["RBLNBaseModel"] = [],
+        rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
         **kwargs,
     ):
+        if rbln_submodules is None:
+            rbln_submodules = []
+
         if isinstance(model_save_dir, str):
             model_save_dir = Path(model_save_dir)
 
@@ -309,6 +311,8 @@
             )
             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
 
+        rbln_config.freeze()
+
         return cls(
             models,
             config,
@@ -447,15 +451,15 @@
         model_config: "PretrainedConfig",
         rbln_config: RBLNModelConfig,
     ) -> RBLNModelConfig:
-        rbln_config.torch_dtype = model.dtype
-        if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
+        rbln_config.dtype = model.dtype
+        if not cls._supports_non_fp32 and rbln_config.dtype != torch.float32:
             raise NotImplementedError(
                 f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
             )
         rbln_config = cls._update_rbln_config(
             preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
         )
-        rbln_config.freeze()
+
         if rbln_config.rbln_model_cls_name != cls.__name__:
             raise NameError(
                 f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
optimum/rbln/ops/__init__.py CHANGED
@@ -16,4 +16,5 @@ from .attn import *
 from .flash_attn import *
 from .kv_cache_update import *
 from .linear import linear
+from .moe import *
 from .sliding_window_attn import *
optimum/rbln/ops/attn.py CHANGED
@@ -205,6 +205,7 @@ def paged_causal_attn_decode(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused attention with KV cache updates.
 
@@ -228,6 +229,7 @@
         - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
         - block_size: [] - Number of tokens per block
         - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+        - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
@@ -247,6 +249,7 @@ def paged_causal_attn_decode_fake(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -267,6 +270,7 @@ def paged_causal_attn_prefill(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -290,6 +294,7 @@
         - block_size: [] - Number of tokens per block
         - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
        - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+        - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
@@ -310,6 +315,7 @@ def paged_causal_attn_prefill_fake(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -331,6 +337,7 @@ def paged_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -349,6 +356,7 @@ def paged_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -371,6 +379,7 @@ def paged_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -390,6 +399,7 @@ def paged_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
optimum/rbln/ops/flash_attn.py CHANGED
@@ -198,6 +198,7 @@ def paged_flash_causal_attn_decode(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
 
@@ -219,6 +220,7 @@ def paged_flash_causal_attn_decode_fake(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -241,6 +243,7 @@ def paged_flash_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -260,6 +263,7 @@ def paged_flash_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -281,6 +285,7 @@ def paged_flash_causal_attn_prefill(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
 
@@ -303,6 +308,7 @@ def paged_flash_causal_attn_prefill_fake(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -326,6 +332,7 @@ def paged_flash_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -346,5 +353,6 @@ def paged_flash_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
optimum/rbln/ops/moe.py ADDED
@@ -0,0 +1,180 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu",
+    mutates_args=(),
+)
+def custom_moe_glu(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+        - hidden_states: [batch*seq_len, hidden_size]
+        - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
+        - up_proj_weight: [num_experts, hidden_size, intermediate_size]
+        - down_proj_weight: [num_experts, intermediate_size, hidden_size]
+        - router_logits: [batch*seq_len, num_experts]
+        - topk: top k experts to select
+        - norm_topk_prob: whether to normalize the top k routing weights with softmax
+        - gate_proj_bias: [num_experts, intermediate_size]
+        - up_proj_bias: [num_experts, intermediate_size]
+        - down_proj_bias: [num_experts, hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu.register_fake
+def custom_moe_glu_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_ff",
+    mutates_args=(),
+)
+def custom_moe_ff(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE FF operation.
+
+    Expected tensor shapes:
+        - hidden_states: [batch * seq_len, hidden_size]
+        - gate_proj_weight: [hidden_size, num_experts * intermediate_size]
+        - down_proj_weight: [num_experts * intermediate_size, hidden_size]
+        - masked_routing_weight: [batch * seq_len, num_experts]
+        - gate_proj_bias: [num_experts * intermediate_size]
+        - down_proj_bias: [hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_ff.register_fake
+def custom_moe_ff_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu_mxfp4",
+    mutates_args=(),
+)
+def custom_moe_glu_mxfp4(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+        - hidden_states: [batch*seq_len, hidden_size]
+        - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+        - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+        - gate_proj_bias: [num_experts, intermediate_size]
+        - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+        - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+        - up_proj_bias: [num_experts, intermediate_size]
+        - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2]
+        - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32]
+        - masked_routing_weight: [batch * seq_len, num_experts]
+        - expert_select_count: [num_experts]
+        - alpha: []
+        - limit: []
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu_mxfp4.register_fake
+def custom_moe_glu_mxfp4_fake(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
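Note: the eager bodies above return torch.empty_like(hidden_states) on purpose; the real computation is lowered to an RBLN kernel at compile time, so only shapes are observable on CPU. A quick shape smoke test (importing optimum.rbln.ops triggers registration via the from .moe import * added above):

    import torch

    import optimum.rbln.ops  # noqa: F401 -- registers rbln_custom_ops::*

    T, H, E, I = 8, 64, 4, 128  # tokens, hidden_size, num_experts, intermediate_size
    out = torch.ops.rbln_custom_ops.custom_moe_glu(
        torch.randn(T, H),     # hidden_states
        torch.randn(E, H, I),  # gate_proj_weight
        torch.randn(E, H, I),  # up_proj_weight
        torch.randn(E, I, H),  # down_proj_weight
        torch.randn(T, E),     # router_logits
        2,                     # topk
        True,                  # norm_topk_prob
    )
    assert out.shape == (T, H)  # stub output; values come from the device kernel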
optimum/rbln/ops/sliding_window_attn.py CHANGED
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 
@@ -33,6 +35,7 @@ def paged_sliding_window_attn_prefill(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -53,6 +56,7 @@
         - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
         - scale: [] - Attention scale factor
         - is_bidirectional: [] - Whether the attention is bidirectional
+        - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
     """
@@ -72,6 +76,7 @@ def paged_sliding_window_attn_prefill_fake(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -91,6 +96,8 @@ def paged_sliding_window_attn_decode(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -107,5 +114,7 @@ def paged_sliding_window_attn_decode_fake(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
optimum/rbln/transformers/__init__.py CHANGED
@@ -78,6 +78,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaModel",
         "RBLNGemmaModelConfig",
+        "RBLNGemma2ForCausalLM",
+        "RBLNGemma2ForCausalLMConfig",
+        "RBLNGemma2Model",
+        "RBLNGemma2ModelConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
@@ -88,6 +92,8 @@ _import_structure = {
         "RBLNGPT2LMHeadModelConfig",
         "RBLNGPT2Model",
         "RBLNGPT2ModelConfig",
+        "RBLNGptOssForCausalLM",
+        "RBLNGptOssForCausalLMConfig",
         "RBLNGroundingDinoDecoder",
         "RBLNGroundingDinoDecoderConfig",
         "RBLNGroundingDinoForObjectDetection",
@@ -110,6 +116,10 @@ _import_structure = {
         "RBLNPegasusForConditionalGenerationConfig",
         "RBLNPegasusModel",
         "RBLNPegasusModelConfig",
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNLoRAAdapterConfig",
@@ -134,14 +144,22 @@ _import_structure = {
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
         "RBLNQwen2VisionTransformerPretrainedModel",
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
         "RBLNQwen2Model",
         "RBLNQwen2ModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
+        "RBLNQwen2MoeForCausalLM",
+        "RBLNQwen2MoeForCausalLMConfig",
+        "RBLNQwen3MoeForCausalLM",
+        "RBLNQwen3MoeForCausalLMConfig",
         "RBLNQwen3ForCausalLM",
         "RBLNQwen3ForCausalLMConfig",
         "RBLNQwen3Model",
@@ -234,6 +252,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma2ForCausalLM,
+        RBLNGemma2ForCausalLMConfig,
+        RBLNGemma2Model,
+        RBLNGemma2ModelConfig,
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
         RBLNGemma3ForConditionalGeneration,
@@ -246,6 +268,8 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGptOssForCausalLM,
+        RBLNGptOssForCausalLMConfig,
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
         RBLNGroundingDinoEncoder,
@@ -276,6 +300,10 @@ if TYPE_CHECKING:
         RBLNOPTForCausalLMConfig,
         RBLNOPTModel,
         RBLNOPTModelConfig,
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
         RBLNPegasusModel,
@@ -290,18 +318,26 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
         RBLNQwen2Model,
         RBLNQwen2ModelConfig,
+        RBLNQwen2MoeForCausalLM,
+        RBLNQwen2MoeForCausalLMConfig,
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
        RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
         RBLNQwen3ForCausalLM,
         RBLNQwen3ForCausalLMConfig,
         RBLNQwen3Model,
         RBLNQwen3ModelConfig,
+        RBLNQwen3MoeForCausalLM,
+        RBLNQwen3MoeForCausalLMConfig,
         RBLNResNetForImageClassification,
         RBLNResNetForImageClassificationConfig,
         RBLNRobertaForMaskedLM,
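Note: the export additions above make the new model classes in this release importable through the package's lazy-module machinery, for example:

    from optimum.rbln.transformers import (
        RBLNGemma2ForCausalLM,
        RBLNGptOssForCausalLM,
        RBLNPaliGemmaForConditionalGeneration,
        RBLNQwen2MoeForCausalLM,
        RBLNQwen3MoeForCausalLM,
    )

Matching entries are added to the top-level optimum/rbln/__init__.py (+48 in the file list), so the same names are expected to resolve from optimum.rbln as well.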