optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling.py
CHANGED
@@ -34,49 +34,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def _get_dtype(
-    cls,
-    dtype: Optional[Union[str, torch.dtype, dict]],
-    config: PretrainedConfig,
-) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
-    dtype_orig = None
-
-    if dtype is not None:
-        if isinstance(dtype, str):
-            if dtype == "auto":
-                if hasattr(config, "dtype") and config.dtype is not None:
-                    dtype = config.dtype
-                else:
-                    dtype = torch.get_default_dtype()
-            elif hasattr(torch, dtype):
-                dtype = getattr(torch, dtype)
-            config.dtype = dtype
-        elif isinstance(dtype, torch.dtype):
-            config.dtype = dtype
-        elif isinstance(dtype, dict):
-            for key, curr_dtype in dtype.items():
-                if hasattr(config, key):
-                    value = getattr(config, key)
-                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
-                    value.dtype = curr_dtype
-            # main torch dtype for modules that aren't part of any sub-config
-            dtype = dtype.get("")
-            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
-            config.dtype = dtype
-            if dtype is None:
-                dtype = torch.float32
-        else:
-            raise ValueError(f"Invalid dtype: {dtype}")
-
-        dtype_orig = cls._set_default_dtype(dtype)
-    else:
-        # Use default dtype
-        default_dtype = torch.get_default_dtype()
-        config.dtype = default_dtype
-
-    return config, dtype, dtype_orig
-
-
 class RBLNModel(RBLNBaseModel):
     @classmethod
     def update_kwargs(cls, kwargs):
@@ -97,13 +54,16 @@ class RBLNModel(RBLNBaseModel):
         pass
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model
 
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-
+        if rbln_config._allow_no_compile_cfgs:
+            return {}
+
+        model = cls._wrap_model_if_needed(model, rbln_config)
         rbln_compile_config = rbln_config.compile_cfgs[0]
         compiled_model = cls.compile(
             model,
@@ -113,6 +73,18 @@
         )
         return compiled_model
 
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Any],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNModelConfig] = None,
+    ) -> RBLNModelConfig:
+        # Default implementation: return config as-is
+        # Subclasses should override to set compile_cfgs if needed
+        return rbln_config
+
     @classmethod
     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         return model
@@ -277,6 +249,9 @@
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNModelConfig,
     ) -> List[rebel.Runtime]:
+        if len(rbln_config.compile_cfgs) == 0:
+            return []
+
         if DEFAULT_COMPILED_MODEL_NAME not in rbln_config.device_map:
             cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
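Taken together, these changes split compilation into two overridable hooks: `_wrap_model_if_needed` swaps the eager module for a compile-friendly wrapper just before `cls.compile(...)`, and `_update_rbln_config` is the point where subclasses attach their `compile_cfgs`. A minimal sketch of a subclass using the wrapper hook; `RBLNMyModel` and `_LogitsOnlyWrapper` are illustrative names, not part of the package, and the top-level exports are assumed:

import torch

from optimum.rbln import RBLNModel, RBLNModelConfig


class _LogitsOnlyWrapper(torch.nn.Module):
    # Illustrative wrapper: reduce the HF output object to a single tensor so
    # the traced graph has a stable, compile-friendly signature.
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids).logits


class RBLNMyModel(RBLNModel):  # hypothetical subclass, for illustration only
    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        # get_compiled_model() now routes every model through this hook
        # right before cls.compile(...).
        return _LogitsOnlyWrapper(model)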
optimum/rbln/modeling_base.py
CHANGED
@@ -15,7 +15,6 @@
 import importlib
 import os
 import shutil
-from abc import ABC
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@@ -39,7 +38,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class PreTrainedModel(ABC):
+class PreTrainedModel:  # noqa: F811
     pass
 
 
@@ -63,7 +62,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
-        rbln_submodules: List["RBLNBaseModel"] = [],
+        rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
         **kwargs,
     ):
         self.model = models
@@ -71,7 +70,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         self.rbln_config = rbln_config
         if not rbln_config.is_frozen():
             raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
-
         self.compiled_models = rbln_compiled_models
 
         # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
@@ -92,7 +90,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
         self.device = torch.device("cpu")
         self.training = False
-        self.dtype = rbln_config.
+        self.dtype = rbln_config.dtype
 
         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -107,6 +105,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         self.model_save_dir = model_save_dir
         self.subfolder = subfolder
 
+        if rbln_submodules is None:
+            rbln_submodules = []
         self.rbln_submodules = rbln_submodules
         self.__post_init__(**kwargs)
 
@@ -182,7 +182,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         # passed from compile function
         rbln_config: Optional[RBLNModelConfig] = None,
         rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
-        rbln_submodules: List["RBLNBaseModel"] = [],
+        rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
         **kwargs,
     ) -> "RBLNBaseModel":
         if rbln_compiled_models is None:
@@ -218,12 +218,11 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         )
 
         if len(cls._rbln_submodules) > 0:
-            rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
-        else:
+            if rbln_submodules is None:
+                rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
+        elif rbln_submodules is None:
             rbln_submodules = []
 
-        rbln_config.freeze()
-
         if config is None:
             if cls.hf_library_name == "transformers":
                 config = AutoConfig.from_pretrained(
@@ -280,9 +279,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         config: "PretrainedConfig",
         model_save_dir: Union[Path, str],
         subfolder: Union[Path, str],
-        rbln_submodules: List["RBLNBaseModel"] = [],
+        rbln_submodules: Optional[List["RBLNBaseModel"]] = None,
         **kwargs,
     ):
+        if rbln_submodules is None:
+            rbln_submodules = []
+
        if isinstance(model_save_dir, str):
            model_save_dir = Path(model_save_dir)
 
@@ -309,6 +311,8 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             )
             raise rebel.core.exception.RBLNRuntimeError(error_msg) from e
 
+        rbln_config.freeze()
+
         return cls(
             models,
             config,
@@ -447,15 +451,15 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
         model_config: "PretrainedConfig",
         rbln_config: RBLNModelConfig,
     ) -> RBLNModelConfig:
-        rbln_config.
-        if not cls._supports_non_fp32 and rbln_config.
+        rbln_config.dtype = model.dtype
+        if not cls._supports_non_fp32 and rbln_config.dtype != torch.float32:
             raise NotImplementedError(
                 f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
             )
         rbln_config = cls._update_rbln_config(
             preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
         )
-
+
         if rbln_config.rbln_model_cls_name != cls.__name__:
             raise NameError(
                 f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
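A recurring change in this file replaces the mutable default `rbln_submodules: List[...] = []` with `Optional[...] = None` plus an explicit None-check. The motivation is the classic Python pitfall that default values are evaluated once, at definition time, so a `[]` default is shared across every call that omits the argument. A self-contained illustration (names are illustrative, not from the package):

def buggy(submodules=[]):      # one shared list for all calls
    submodules.append("m")
    return submodules

def fixed(submodules=None):    # fresh list per call, as in the new code
    if submodules is None:
        submodules = []
    submodules.append("m")
    return submodules

assert buggy() == ["m"]
assert buggy() == ["m", "m"]   # state leaked from the first call
assert fixed() == ["m"]
assert fixed() == ["m"]        # independent every time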
optimum/rbln/ops/__init__.py
CHANGED
optimum/rbln/ops/attn.py
CHANGED
@@ -205,6 +205,7 @@ def paged_causal_attn_decode(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused attention with KV cache updates.
 
@@ -228,6 +229,7 @@ def paged_causal_attn_decode(
     - block_table: [batch_size, max_seq_len // block_size] - Block indices for KV cache management
     - block_size: [] - Number of tokens per block
     - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
@@ -247,6 +249,7 @@ def paged_causal_attn_decode_fake(
     block_table: Tensor,
     block_size: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -267,6 +270,7 @@ def paged_causal_attn_prefill(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -290,6 +294,7 @@ def paged_causal_attn_prefill(
     - block_size: [] - Number of tokens per block
     - is_bidirectional: [] - Whether the attention is bidirectional at current sequence position
     - mask: [batch=1, max_seq_len] - attention mask when use position_ids
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
 
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
@@ -310,6 +315,7 @@ def paged_causal_attn_prefill_fake(
     block_size: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -331,6 +337,7 @@ def paged_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -349,6 +356,7 @@ def paged_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -371,6 +379,7 @@ def paged_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -390,6 +399,7 @@ def paged_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
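Every attention op in this file (and in flash_attn.py and sliding_window_attn.py below) gains a trailing `s_aux: Optional[Tensor] = None`, documented as `[num_attention_heads, sink_len]`; making it optional and last keeps existing call sites and op schemas valid. Assuming the usual attention-sink semantics (per-head sink logits that join the softmax normalization but contribute no value vector, as in gpt-oss-style models), an eager-mode sketch of the intended math might look like this; the function name and shapes are illustrative, not the device kernel:

import torch

def attn_with_sinks(q, k, v, s_aux, scale):
    # q: [heads, q_len, dim]; k, v: [heads, kv_len, dim]; s_aux: [heads, sink_len]
    scores = torch.einsum("hqd,hkd->hqk", q, k) * scale             # [heads, q_len, kv_len]
    sinks = s_aux[:, None, :].expand(-1, q.shape[1], -1)            # broadcast over queries
    logits = torch.cat([scores, sinks], dim=-1)                     # sinks join the softmax
    probs = torch.softmax(logits, dim=-1)[..., : scores.shape[-1]]  # then drop sink columns
    return torch.einsum("hqk,hkd->hqd", probs, v)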
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -198,6 +198,7 @@ def paged_flash_causal_attn_decode(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
 
@@ -219,6 +220,7 @@ def paged_flash_causal_attn_decode_fake(
     block_size: int,
     partition: int,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -241,6 +243,7 @@ def paged_flash_causal_attn_decode_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -260,6 +263,7 @@ def paged_flash_causal_attn_decode_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -281,6 +285,7 @@ def paged_flash_causal_attn_prefill(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
 
@@ -303,6 +308,7 @@ def paged_flash_causal_attn_prefill_fake(
     partition: int,
     is_bidirectional: bool,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -326,6 +332,7 @@ def paged_flash_causal_attn_prefill_kv_fp8(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -346,5 +353,6 @@ def paged_flash_causal_attn_prefill_kv_fp8_fake(
     k_scale: Tensor,
     v_scale: Tensor,
     mask: Optional[Tensor] = None,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
optimum/rbln/ops/moe.py
ADDED
@@ -0,0 +1,180 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu",
+    mutates_args=(),
+)
+def custom_moe_glu(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+    - hidden_states: [batch*seq_len, hidden_size]
+    - gate_proj_weight: [num_experts, hidden_size, intermediate_size]
+    - up_proj_weight: [num_experts, hidden_size, intermediate_size]
+    - down_proj_weight: [num_experts, intermediate_size, hidden_size]
+    - router_logits: [batch*seq_len, num_experts]
+    - topk: top k experts to select
+    - norm_topk_prob: whether to normalize the top k routing weights with softmax
+    - gate_proj_bias: [num_experts, intermediate_size]
+    - up_proj_bias: [num_experts, intermediate_size]
+    - down_proj_bias: [num_experts, hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu.register_fake
+def custom_moe_glu_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    up_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    router_logits: Tensor,
+    topk: int,
+    norm_topk_prob: bool,
+    gate_proj_bias: Optional[Tensor] = None,
+    up_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_ff",
+    mutates_args=(),
+)
+def custom_moe_ff(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    """
+    Customized MoE FF operation.
+
+    Expected tensor shapes:
+    - hidden_states: [batch * seq_len, hidden_size]
+    - gate_proj_weight: [hidden_size, num_experts * intermediate_size]
+    - down_proj_weight: [num_experts * intermediate_size, hidden_size]
+    - masked_routing_weight: [batch * seq_len, num_experts]
+    - gate_proj_bias: [num_experts * intermediate_size]
+    - down_proj_bias: [hidden_size]
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_ff.register_fake
+def custom_moe_ff_fake(
+    hidden_states: Tensor,
+    gate_proj_weight: Tensor,
+    down_proj_weight: Tensor,
+    masked_routing_weight: Tensor,
+    gate_proj_bias: Optional[Tensor] = None,
+    down_proj_bias: Optional[Tensor] = None,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
+
+
+@torch.library.custom_op(
+    "rbln_custom_ops::custom_moe_glu_mxfp4",
+    mutates_args=(),
+)
+def custom_moe_glu_mxfp4(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    """
+    Customized MoE GLU operation.
+
+    Expected tensor shapes:
+    - hidden_states: [batch*seq_len, hidden_size]
+    - gate_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+    - gate_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+    - gate_proj_bias: [num_experts, intermediate_size]
+    - up_proj_blocks: [num_experts, intermediate_size, hidden_size // 2]
+    - up_proj_scales: [num_experts, intermediate_size, hidden_size // 32]
+    - up_proj_bias: [num_experts, intermediate_size]
+    - down_proj_blocks: [num_experts, hidden_size, intermediate_size // 2]
+    - down_proj_scales: [num_experts, hidden_size, intermediate_size // 32]
+    - masked_routing_weight: [batch * seq_len, num_experts]
+    - expert_select_count: [num_experts]
+    - alpha: []
+    - limit: []
+
+    Returns:
+        Tensor: [batch * seq_len, hidden_size]
+    """
+
+    return torch.empty_like(hidden_states)
+
+
+@custom_moe_glu_mxfp4.register_fake
+def custom_moe_glu_mxfp4_fake(
+    hidden_states: Tensor,
+    gate_proj_blocks: Tensor,
+    gate_proj_scales: Tensor,
+    gate_proj_bias: Tensor,
+    up_proj_blocks: Tensor,
+    up_proj_scales: Tensor,
+    up_proj_bias: Tensor,
+    down_proj_blocks: Tensor,
+    down_proj_scales: Tensor,
+    down_proj_bias: Tensor,
+    router_logits: Tensor,
+    alpha: Tensor,
+    limit: Tensor,
+    k: int,
+    post_norm: bool,
+) -> Tensor:
+    return torch.empty_like(hidden_states)
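Note that every op body in this new file just returns `torch.empty_like(...)`: `torch.library.custom_op` registers the schema, `register_fake` supplies the shape-propagation (meta) kernel used during tracing, and the actual computation is lowered by the RBLN compiler. As a rough eager-mode reference for what the `custom_moe_glu` schema appears to describe (an interpretation of its docstring, with biases omitted and the Qwen-style `norm_topk_prob` convention assumed):

import torch
import torch.nn.functional as F

def moe_glu_reference(hidden_states, gate_w, up_w, down_w, router_logits, topk, norm_topk_prob):
    # hidden_states: [tokens, hidden]; gate_w, up_w: [experts, hidden, inter];
    # down_w: [experts, inter, hidden]; router_logits: [tokens, experts]
    routing = F.softmax(router_logits, dim=-1)
    weights, experts = torch.topk(routing, topk, dim=-1)   # both [tokens, topk]
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):                # per-token loop, clarity over speed
        x = hidden_states[t]
        for w, e in zip(weights[t], experts[t]):
            h = F.silu(x @ gate_w[e]) * (x @ up_w[e])      # GLU expert MLP
            out[t] += w * (h @ down_w[e])
    return out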
optimum/rbln/ops/sliding_window_attn.py
CHANGED

@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from typing import Optional
+
 import torch
 from torch import Tensor
 
@@ -33,6 +35,7 @@ def paged_sliding_window_attn_prefill(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     """Defines the computation pattern for prefill phase attention with KV cache updates.
 
@@ -53,6 +56,7 @@ def paged_sliding_window_attn_prefill(
     - cache_offset: [] - The valid length in the combined sequence of the KV cache and the current projected key states.
     - scale: [] - Attention scale factor
     - is_bidirectional: [] - Whether the attention is bidirectional
+    - s_aux: [num_attention_heads, sink_len] - auxiliary states for attention
     Returns:
         Tensor: attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
     """
@@ -72,6 +76,7 @@ def paged_sliding_window_attn_prefill_fake(
     block_table: Tensor,
     block_size: int,
     is_bidirectional: bool,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -91,6 +96,8 @@ def paged_sliding_window_attn_decode(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
 
@@ -107,5 +114,7 @@ def paged_sliding_window_attn_decode_fake(
     scale: Tensor,
     block_table: Tensor,
     block_size: int,
+    attn_mask: Tensor,
+    s_aux: Optional[Tensor] = None,
 ) -> Tensor:
     return torch.empty_like(q)
optimum/rbln/transformers/__init__.py
CHANGED

@@ -78,6 +78,10 @@ _import_structure = {
         "RBLNExaoneForCausalLMConfig",
         "RBLNGemmaModel",
         "RBLNGemmaModelConfig",
+        "RBLNGemma2ForCausalLM",
+        "RBLNGemma2ForCausalLMConfig",
+        "RBLNGemma2Model",
+        "RBLNGemma2ModelConfig",
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
@@ -88,6 +92,8 @@ _import_structure = {
         "RBLNGPT2LMHeadModelConfig",
         "RBLNGPT2Model",
         "RBLNGPT2ModelConfig",
+        "RBLNGptOssForCausalLM",
+        "RBLNGptOssForCausalLMConfig",
         "RBLNGroundingDinoDecoder",
         "RBLNGroundingDinoDecoderConfig",
         "RBLNGroundingDinoForObjectDetection",
@@ -110,6 +116,10 @@ _import_structure = {
         "RBLNPegasusForConditionalGenerationConfig",
         "RBLNPegasusModel",
         "RBLNPegasusModelConfig",
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
         "RBLNLoRAAdapterConfig",
@@ -134,14 +144,22 @@ _import_structure = {
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
         "RBLNQwen2VisionTransformerPretrainedModel",
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
         "RBLNQwen2Model",
         "RBLNQwen2ModelConfig",
         "RBLNQwen2ForCausalLM",
         "RBLNQwen2ForCausalLMConfig",
+        "RBLNQwen2MoeForCausalLM",
+        "RBLNQwen2MoeForCausalLMConfig",
+        "RBLNQwen3MoeForCausalLM",
+        "RBLNQwen3MoeForCausalLMConfig",
         "RBLNQwen3ForCausalLM",
         "RBLNQwen3ForCausalLMConfig",
         "RBLNQwen3Model",
@@ -234,6 +252,10 @@ if TYPE_CHECKING:
         RBLNDPTForDepthEstimationConfig,
         RBLNExaoneForCausalLM,
         RBLNExaoneForCausalLMConfig,
+        RBLNGemma2ForCausalLM,
+        RBLNGemma2ForCausalLMConfig,
+        RBLNGemma2Model,
+        RBLNGemma2ModelConfig,
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
         RBLNGemma3ForConditionalGeneration,
@@ -246,6 +268,8 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGptOssForCausalLM,
+        RBLNGptOssForCausalLMConfig,
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
         RBLNGroundingDinoEncoder,
@@ -276,6 +300,10 @@ if TYPE_CHECKING:
         RBLNOPTForCausalLMConfig,
         RBLNOPTModel,
         RBLNOPTModelConfig,
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
         RBLNPegasusModel,
@@ -290,18 +318,26 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
         RBLNQwen2ForCausalLM,
         RBLNQwen2ForCausalLMConfig,
         RBLNQwen2Model,
         RBLNQwen2ModelConfig,
+        RBLNQwen2MoeForCausalLM,
+        RBLNQwen2MoeForCausalLMConfig,
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
         RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
         RBLNQwen3ForCausalLM,
         RBLNQwen3ForCausalLMConfig,
         RBLNQwen3Model,
         RBLNQwen3ModelConfig,
+        RBLNQwen3MoeForCausalLM,
+        RBLNQwen3MoeForCausalLMConfig,
         RBLNResNetForImageClassification,
         RBLNResNetForImageClassificationConfig,
         RBLNRobertaForMaskedLM,