optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
|
@@ -12,24 +12,17 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING
|
|
16
|
-
|
|
17
|
-
from transformers import PretrainedConfig
|
|
18
15
|
|
|
19
16
|
from ....utils import logging
|
|
20
17
|
from ...models.decoderonly import (
|
|
21
18
|
RBLNDecoderOnlyModel,
|
|
22
19
|
RBLNDecoderOnlyModelForCausalLM,
|
|
23
|
-
RBLNDecoderOnlyModelForCausalLMConfig,
|
|
24
20
|
)
|
|
25
21
|
from .qwen3_architecture import Qwen3Wrapper
|
|
26
22
|
|
|
27
23
|
|
|
28
24
|
logger = logging.get_logger(__name__)
|
|
29
25
|
|
|
30
|
-
if TYPE_CHECKING:
|
|
31
|
-
from transformers import PretrainedConfig
|
|
32
|
-
|
|
33
26
|
|
|
34
27
|
class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
35
28
|
"""
|
|
@@ -84,19 +77,6 @@ class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
84
77
|
|
|
85
78
|
_decoder_wrapper_cls = Qwen3Wrapper
|
|
86
79
|
|
|
87
|
-
@classmethod
|
|
88
|
-
def _update_sliding_window_config(
|
|
89
|
-
cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
90
|
-
):
|
|
91
|
-
# https://github.com/huggingface/transformers/issues/35896
|
|
92
|
-
# There seems to be a bug in transformers(v4.52.4). Therefore, similar to when attn_implementation is eager,
|
|
93
|
-
# we set all layers to use sliding window in this version. This should be updated once the bug is fixed.
|
|
94
|
-
|
|
95
|
-
rbln_config.cache_impl = "sliding_window"
|
|
96
|
-
rbln_config.sliding_window = model_config.sliding_window
|
|
97
|
-
rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
|
|
98
|
-
return rbln_config
|
|
99
|
-
|
|
100
80
|
def forward(self, *args, **kwargs):
|
|
101
81
|
kwargs["return_dict"] = True
|
|
102
82
|
return super().forward(*args, **kwargs)
|
|
@@ -22,10 +22,10 @@ class Qwen3Wrapper(DecoderOnlyWrapper):
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
class Qwen3Attention(DecoderOnlyAttention):
|
|
25
|
-
def __post_init__(self):
|
|
26
|
-
self.
|
|
27
|
-
self.
|
|
28
|
-
self.
|
|
29
|
-
self.o_proj =
|
|
30
|
-
self.q_norm =
|
|
31
|
-
self.k_norm =
|
|
25
|
+
def __post_init__(self, self_attn):
|
|
26
|
+
self.q_proj = self_attn.q_proj
|
|
27
|
+
self.k_proj = self_attn.k_proj
|
|
28
|
+
self.v_proj = self_attn.v_proj
|
|
29
|
+
self.o_proj = self_attn.o_proj
|
|
30
|
+
self.q_norm = self_attn.q_norm
|
|
31
|
+
self.k_norm = self_attn.k_norm
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_qwen3_moe import RBLNQwen3MoeForCausalLMConfig
|
|
16
|
+
from .modeling_qwen3_moe import RBLNQwen3MoeForCausalLM
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RBLNQwen3MoeForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    Configuration class for RBLN Qwen3-MoE causal language models.

    It defines no additional options and behaves exactly like
    [`RBLNDecoderOnlyModelForCausalLMConfig`]; every argument accepted there is accepted here.

    Example usage:
    ```python
    from optimum.rbln import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig

    # Build the configuration object
    config = RBLNQwen3MoeForCausalLMConfig(
        batch_size=1,
        max_seq_len=262144,
        tensor_parallel_size=4
    )

    # Hand it to from_pretrained
    model = RBLNQwen3MoeForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        export=True,
        rbln_config=config
    )
    ```
    """
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
16
|
+
from .qwen3_moe_architecture import Qwen3MoeWrapper
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RBLNQwen3MoeForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    Qwen3-MoE is the Mixture-of-Experts (MoE) variant of Qwen3, available as a base model and an aligned chat model.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based Qwen3MoeForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers Qwen3MoeForCausalLM model into a RBLN transformer model by:

    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    **Configuration:**
    This model uses [`RBLNQwen3MoeForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNQwen3MoeForCausalLMConfig`] or a dictionary conforming to its structure.

    See the [`RBLNQwen3MoeForCausalLMConfig`] class for all available configuration options.

    Examples:
        ```python
        from optimum.rbln import RBLNQwen3MoeForCausalLM

        # Simple usage using rbln_* arguments
        # `max_seq_len` is automatically inferred from the model config
        model = RBLNQwen3MoeForCausalLM.from_pretrained(
            "Qwen/Qwen3-30B-A3B-Thinking-2507",
            export=True,
            rbln_batch_size=1,
            rbln_tensor_parallel_size=4,
        )

        # Using a config dictionary
        rbln_config = {
            "batch_size": 1,
            "max_seq_len": 262144,
            "tensor_parallel_size": 4,
        }
        model = RBLNQwen3MoeForCausalLM.from_pretrained(
            "Qwen/Qwen3-30B-A3B-Thinking-2507",
            export=True,
            rbln_config=rbln_config
        )

        # Using a RBLNQwen3MoeForCausalLMConfig instance (recommended for type checking)
        from optimum.rbln import RBLNQwen3MoeForCausalLMConfig

        config = RBLNQwen3MoeForCausalLMConfig(
            batch_size=1,
            max_seq_len=262144,
            tensor_parallel_size=4
        )
        model = RBLNQwen3MoeForCausalLM.from_pretrained(
            "Qwen/Qwen3-30B-A3B-Thinking-2507",
            export=True,
            rbln_config=config
        )
        ```
    """

    # Wrapper handed to the decoder-only base class; it supplies the Qwen3-MoE
    # layer and attention adapters (see Qwen3MoeWrapper).
    _decoder_wrapper_cls = Qwen3MoeWrapper
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
|
|
20
|
+
from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
|
|
21
|
+
from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Qwen3MoeWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper that plugs in the Qwen3-MoE layer and attention adapter classes."""

    def get_rbln_layer_class(self):
        """Return the class used to wrap each decoder layer."""
        layer_cls = Qwen3MoeLayer
        return layer_cls

    def get_rbln_attn_class(self):
        """Return the class used to wrap each self-attention module."""
        attn_cls = Qwen3MoeAttention
        return attn_cls
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Qwen3MoeAttention(DecoderOnlyAttention):
    """Attention adapter that also carries the source module's ``q_norm`` / ``k_norm`` layers."""

    def __post_init__(self, self_attn):
        # Reuse (not copy) the source attention's projection and normalization
        # submodules, registering them in the same order as before.
        for attr_name in ("q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"):
            setattr(self, attr_name, getattr(self_attn, attr_name))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Qwen3MoeLayer(DecoderOnlyLayer):
    """Decoder layer that swaps a Hugging Face sparse-MoE MLP for the RBLN MoE block."""

    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
        super().__init__(layer, self_attn, lora_config)
        source_mlp = layer.mlp
        # Match by class name so the Hugging Face class never has to be imported here.
        # Any other MLP type (presumably the dense MLP of non-sparse layers) is kept as-is.
        is_sparse_moe = type(source_mlp).__name__ == "Qwen3MoeSparseMoeBlock"
        self.mlp = Qwen3MoeSparseMoeBlock(source_mlp) if is_sparse_moe else source_mlp

    def get_mlp(self) -> nn.Module:
        """Return this layer's (possibly replaced) MLP module."""
        return self.mlp
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Qwen3MoeSparseMoeBlock(nn.Module):
    """RBLN replacement for a Hugging Face ``Qwen3MoeSparseMoeBlock``.

    It keeps the source block's router (``gate``) and routing settings, and replaces
    its list of experts with a single fused :class:`Qwen3MoeMLP`.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        # Keep registration order unchanged: num_experts, top_k, norm_topk_prob, gate, experts.
        self.num_experts = model.num_experts
        self.top_k = model.top_k
        self.norm_topk_prob = model.norm_topk_prob
        self.gate = model.gate
        self.experts = Qwen3MoeMLP(model.experts, self.top_k, self.norm_topk_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through its experts and return the combined result.

        Args:
            hidden_states: Rank-3 tensor ``(batch_size, sequence_length, hidden_dim)``
                (the shape unpacking below requires exactly three dimensions).

        Returns:
            Tensor reshaped back to ``(batch_size, sequence_length, hidden_dim)``.
        """
        bsz, seq_len, dim = hidden_states.shape
        tokens = hidden_states.view(-1, dim)  # (bsz * seq_len, dim)

        # Router scores, one row per token: (bsz * seq_len, n_experts)
        logits = self.gate(tokens)
        mixed = self.experts(tokens, logits)

        return mixed.reshape(bsz, seq_len, dim)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class Qwen3MoeMLP(nn.Module):
    """Fuse per-expert MLPs into stacked weights consumed by the RBLN MoE GLU custom op.

    Each expert's ``gate_proj`` / ``up_proj`` / ``down_proj`` weight is stacked along a
    new leading expert dimension. If each projection is an ``nn.Linear`` (weight shape
    ``(out_features, in_features)``), the stacked shapes are:

    - ``gate_proj.weight`` / ``up_proj.weight``: ``(num_experts, intermediate_size, hidden_size)``
    - ``down_proj.weight``: ``(num_experts, hidden_size, intermediate_size)``

    Args:
        expert_list: Non-empty sequence of expert modules. Each must expose ``gate_proj``,
            ``up_proj`` and ``down_proj`` submodules with a ``weight`` tensor; sizes are
            read from ``expert_list[0].hidden_size`` / ``.intermediate_size``.
        top_k: Forwarded to the custom op (number of experts selected per token).
        norm_topk_prob: Forwarded to the custom op (presumably whether the top-k routing
            probabilities are renormalized; confirm against the op definition).
    """

    def __init__(self, expert_list, top_k, norm_topk_prob):
        super().__init__()
        self.hidden_size = expert_list[0].hidden_size
        self.intermediate_size = expert_list[0].intermediate_size
        self.top_k = top_k
        self.norm_topk_prob = norm_topk_prob
        self.num_experts = len(expert_list)

        fused_size = self.num_experts * self.intermediate_size
        # The nn.Linear containers keep the parameter names (``gate_proj.weight`` ...)
        # unchanged. They are created on the meta device because their 2-D weights are
        # replaced right away by the stacked 3-D expert weights. Materializing and randomly
        # initializing a (num_experts * intermediate_size, hidden_size) matrix only to
        # discard it wastes a lot of memory and time for large MoE models.
        self.gate_proj = nn.Linear(self.hidden_size, fused_size, bias=False, device="meta")
        self.up_proj = nn.Linear(self.hidden_size, fused_size, bias=False, device="meta")
        self.down_proj = nn.Linear(fused_size, self.hidden_size, bias=False, device="meta")

        self.gate_proj.weight = nn.Parameter(self._stack_expert_weights(expert_list, "gate_proj"))
        self.up_proj.weight = nn.Parameter(self._stack_expert_weights(expert_list, "up_proj"))
        self.down_proj.weight = nn.Parameter(self._stack_expert_weights(expert_list, "down_proj"))

    @staticmethod
    def _stack_expert_weights(expert_list, proj_name):
        """Stack ``expert.<proj_name>.weight`` of every expert along a new dim 0 (detached)."""
        return torch.stack([getattr(expert, proj_name).weight.data for expert in expert_list], dim=0)

    def forward(self, x, router_logits):
        """Run the fused MoE GLU custom op.

        Args:
            x: Token hidden states; presumably ``(num_tokens, hidden_size)`` (TODO confirm
                against ``rbln_custom_ops.custom_moe_glu``).
            router_logits: Router scores per token and expert.

        Returns:
            Whatever ``torch.ops.rbln_custom_ops.custom_moe_glu`` returns for these inputs.
        """
        return torch.ops.rbln_custom_ops.custom_moe_glu(
            x,
            self.gate_proj.weight,
            self.up_proj.weight,
            self.down_proj.weight,
            router_logits,
            self.top_k,
            self.norm_topk_prob,
        )
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
16
18
|
from ...configuration_generic import RBLNModelForImageClassificationConfig
|
|
17
19
|
|
|
18
20
|
|
|
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
|
|
|
23
25
|
This configuration class stores the configuration parameters specific to
|
|
24
26
|
RBLN-optimized ResNet models for image classification tasks.
|
|
25
27
|
"""
|
|
28
|
+
|
|
29
|
+
def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
    """
    Args:
        image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
            Can be an integer for square images or a tuple (height, width).
        batch_size (Optional[int]): The batch size for inference. Defaults to 1.
        output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
            Stored as given; `None` means "not specified".
        kwargs: Additional arguments passed to the parent RBLNModelForImageClassificationConfig.

    Raises:
        ValueError: If batch_size is not a positive integer.
    """
    super().__init__(**kwargs)
    self.output_hidden_states = output_hidden_states
|
|
@@ -13,7 +13,17 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
|
|
20
|
+
|
|
16
21
|
from ...modeling_generic import RBLNModelForImageClassification
|
|
22
|
+
from .configuration_resnet import RBLNResNetForImageClassificationConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
17
27
|
|
|
18
28
|
|
|
19
29
|
class RBLNResNetForImageClassification(RBLNModelForImageClassification):
|
|
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
|
|
|
24
34
|
on RBLN devices, supporting image classification with convolutional neural networks
|
|
25
35
|
designed for computer vision tasks.
|
|
26
36
|
"""
|
|
37
|
+
|
|
38
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
    model: Optional["PreTrainedModel"] = None,
    model_config: Optional["PretrainedConfig"] = None,
    rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
) -> "RBLNResNetForImageClassificationConfig":
    """Resolve ``output_hidden_states`` and delegate the remaining updates to the parent class.

    An explicit ``rbln_config.output_hidden_states`` is kept. Otherwise the value is taken
    from ``model_config.output_hidden_states``, or ``False`` when that attribute is absent.
    """
    if rbln_config.output_hidden_states is None:
        inherited_flag = getattr(model_config, "output_hidden_states", False)
        rbln_config.output_hidden_states = inherited_flag

    updated_config = super()._update_rbln_config(
        preprocessors=preprocessors,
        model=model,
        model_config=model_config,
        rbln_config=rbln_config,
    )
    return updated_config
|
|
57
|
+
|
|
58
|
+
@classmethod
def _wrap_model_if_needed(
    cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
) -> torch.nn.Module:
    """Wrap ``model`` so every call passes the configured ``output_hidden_states`` value.

    The wrapper injects ``output_hidden_states`` itself, so a caller that also passes that
    keyword gets a ``TypeError`` for the duplicate argument.
    """
    hidden_states_flag = rbln_config.output_hidden_states

    class _ResNetForImageClassification(torch.nn.Module):
        def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
            super().__init__()
            self.model = model
            self.output_hidden_states = output_hidden_states

        def forward(self, *args, **kwargs):
            return self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)

    wrapped = _ResNetForImageClassification(model, hidden_states_flag)
    return wrapped
|
|
73
|
+
|
|
74
|
+
def forward(
    self,
    pixel_values: torch.Tensor,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
    """
    Forward pass for the RBLN-optimized ResNet model for image classification.

    Args:
        pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
        output_hidden_states (bool, *optional*): Whether or not to return the hidden states of all layers.
            Defaults to `rbln_config.output_hidden_states`, and any other value is rejected.
            See hidden_states under returned tensors for more details.
        return_dict (bool, *optional*): Whether to return a dictionary of outputs. Forwarded to the parent `forward`.

    Returns:
        The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a ImageClassifierOutputWithNoAttention object.

    Raises:
        ValueError: If `output_hidden_states` differs from `rbln_config.output_hidden_states`.
    """
    if output_hidden_states is None:
        output_hidden_states = self.rbln_config.output_hidden_states

    if output_hidden_states != self.rbln_config.output_hidden_states:
        raise ValueError(
            f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
            "Please compile again with the correct argument."
        )

    return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
|
|
@@ -12,6 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from typing import Tuple, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
|
|
19
|
+
|
|
15
20
|
from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
|
|
16
21
|
|
|
17
22
|
|
|
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
|
|
|
26
31
|
|
|
27
32
|
rbln_model_input_names = ["input_ids", "attention_mask"]
|
|
28
33
|
|
|
34
|
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
    """
    Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.

    Args:
        input_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of input sequence tokens in the vocabulary.
        attention_mask (torch.FloatTensor of shape (batch_size, sequence_length)): Mask to avoid performing attention on padding token indices.
        **kwargs: Additional keyword arguments forwarded unchanged to the parent class's `forward`.

    Returns:
        The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
    """
    return super().forward(input_ids, attention_mask, **kwargs)
|
|
46
|
+
|
|
29
47
|
|
|
30
48
|
class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
|
|
31
49
|
"""
|
|
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
|
|
|
37
55
|
"""
|
|
38
56
|
|
|
39
57
|
rbln_model_input_names = ["input_ids", "attention_mask"]
|
|
58
|
+
|
|
59
|
+
def forward(
|
|
60
|
+
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
|
|
61
|
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
|
62
|
+
"""
|
|
63
|
+
Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
|
|
67
|
+
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
|
|
71
|
+
"""
|
|
72
|
+
return super().forward(input_ids, attention_mask, **kwargs)
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
from typing import Any, Optional
|
|
16
16
|
|
|
17
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
from ....utils.deprecation import deprecate_kwarg
|
|
18
19
|
from ....utils.logging import get_logger
|
|
19
20
|
|
|
20
21
|
|
|
@@ -24,13 +25,13 @@ logger = get_logger()
|
|
|
24
25
|
class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
|
|
25
26
|
support_paged_attention = None
|
|
26
27
|
|
|
28
|
+
@deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
|
|
27
29
|
def __init__(
|
|
28
30
|
self,
|
|
29
31
|
batch_size: Optional[int] = None,
|
|
30
32
|
enc_max_seq_len: Optional[int] = None,
|
|
31
33
|
dec_max_seq_len: Optional[int] = None,
|
|
32
34
|
use_attention_mask: Optional[bool] = None,
|
|
33
|
-
pad_token_id: Optional[int] = None,
|
|
34
35
|
kvcache_num_blocks: Optional[int] = None,
|
|
35
36
|
kvcache_block_size: Optional[int] = None,
|
|
36
37
|
**kwargs: Any,
|
|
@@ -41,7 +42,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
|
|
|
41
42
|
enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
|
|
42
43
|
dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
|
|
43
44
|
use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
|
|
44
|
-
pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
|
|
45
45
|
kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
|
|
46
46
|
PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
|
|
47
47
|
kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
|
|
@@ -61,8 +61,6 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
|
|
|
61
61
|
|
|
62
62
|
self.use_attention_mask = use_attention_mask
|
|
63
63
|
|
|
64
|
-
self.pad_token_id = pad_token_id
|
|
65
|
-
|
|
66
64
|
if self.support_paged_attention:
|
|
67
65
|
self.kvcache_num_blocks = kvcache_num_blocks
|
|
68
66
|
self.kvcache_block_size = kvcache_block_size
|
|
@@ -20,8 +20,9 @@ import rebel
|
|
|
20
20
|
import torch
|
|
21
21
|
from rebel.compile_context import CompileContext
|
|
22
22
|
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
|
23
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
23
24
|
from transformers.generation.utils import GenerationMixin
|
|
24
|
-
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
|
|
25
|
+
from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
|
|
25
26
|
|
|
26
27
|
from ....configuration_utils import RBLNCompileConfig
|
|
27
28
|
from ....modeling import RBLNModel
|
|
@@ -33,7 +34,7 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
|
|
|
33
34
|
logger = get_logger(__name__)
|
|
34
35
|
|
|
35
36
|
if TYPE_CHECKING:
|
|
36
|
-
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer,
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
|
37
38
|
|
|
38
39
|
|
|
39
40
|
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
|
@@ -140,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
|
140
141
|
@classmethod
|
|
141
142
|
@torch.inference_mode()
|
|
142
143
|
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
|
143
|
-
wrapped_model = cls.
|
|
144
|
+
wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
144
145
|
|
|
145
146
|
enc_compile_config = rbln_config.compile_cfgs[0]
|
|
146
147
|
dec_compile_config = rbln_config.compile_cfgs[1]
|
|
@@ -209,8 +210,8 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
|
209
210
|
if not cls.support_causal_attn:
|
|
210
211
|
rbln_config.use_attention_mask = True
|
|
211
212
|
|
|
212
|
-
n_layer = getattr(model_config, "decoder_layers", None) or
|
|
213
|
-
n_head = getattr(model_config, "decoder_attention_heads", None) or
|
|
213
|
+
n_layer = getattr(model_config, "decoder_layers", None) or model_config.num_layers
|
|
214
|
+
n_head = getattr(model_config, "decoder_attention_heads", None) or model_config.num_heads
|
|
214
215
|
d_kv = (
|
|
215
216
|
model_config.d_kv
|
|
216
217
|
if hasattr(model_config, "d_kv")
|
|
@@ -221,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
|
221
222
|
model_config, "max_position_embeddings", None
|
|
222
223
|
)
|
|
223
224
|
|
|
224
|
-
pad_token_id = getattr(model_config, "pad_token_id", None)
|
|
225
|
-
pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
|
|
226
|
-
pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
|
|
227
|
-
pad_token_id = pad_token_id or -1
|
|
228
|
-
rbln_config.pad_token_id = pad_token_id
|
|
229
|
-
|
|
230
225
|
if rbln_config.enc_max_seq_len is None:
|
|
231
226
|
enc_max_seq_len = max_position_embeddings
|
|
232
227
|
for tokenizer in preprocessors:
|
|
@@ -432,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
|
432
427
|
inputs_tensor = torch.nn.functional.pad(
|
|
433
428
|
inputs_tensor,
|
|
434
429
|
(0, self.rbln_config.enc_max_seq_len - input_len),
|
|
435
|
-
value=self.
|
|
430
|
+
value=self.config.pad_token_id,
|
|
436
431
|
)
|
|
437
432
|
model_kwargs["attention_mask"] = torch.nn.functional.pad(
|
|
438
433
|
model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
|
|
@@ -451,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
|
451
446
|
model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
|
|
452
447
|
|
|
453
448
|
return model_kwargs
|
|
449
|
+
|
|
450
|
+
def generate(
|
|
451
|
+
self,
|
|
452
|
+
input_ids: torch.LongTensor,
|
|
453
|
+
attention_mask: Optional[torch.LongTensor] = None,
|
|
454
|
+
generation_config: Optional[GenerationConfig] = None,
|
|
455
|
+
**kwargs,
|
|
456
|
+
) -> Union[ModelOutput, torch.LongTensor]:
|
|
457
|
+
"""
|
|
458
|
+
The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
|
|
459
|
+
Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
input_ids (torch.LongTensor): The input ids to the model.
|
|
463
|
+
attention_mask (torch.LongTensor, optional): The attention mask to the model.
|
|
464
|
+
generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
|
|
465
|
+
If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
|
|
466
|
+
Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
|
|
467
|
+
kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
|
|
468
|
+
|
|
469
|
+
Returns:
|
|
470
|
+
Generates sequences of token ids for models with a language modeling head.
|
|
471
|
+
"""
|
|
472
|
+
if generation_config is not None:
|
|
473
|
+
kwargs["generation_config"] = generation_config
|
|
474
|
+
if attention_mask is not None:
|
|
475
|
+
kwargs["attention_mask"] = attention_mask
|
|
476
|
+
|
|
477
|
+
return super().generate(input_ids, **kwargs)
|
|
@@ -268,13 +268,12 @@ class Seq2SeqDecoder(torch.nn.Module):
|
|
|
268
268
|
|
|
269
269
|
def __init__(self, model, layers, **kwargs):
|
|
270
270
|
super().__init__()
|
|
271
|
-
self._original_mod = model
|
|
272
271
|
self.layers = nn.ModuleList(layers)
|
|
273
272
|
self.embed_tokens = model.embed_tokens
|
|
274
|
-
self.final_layer_norm = getattr(model, "final_layer_norm", None)
|
|
275
|
-
self.__post_init__(**kwargs)
|
|
273
|
+
self.final_layer_norm = getattr(model, "final_layer_norm", None) or getattr(model, "layer_norm", None)
|
|
274
|
+
self.__post_init__(model, **kwargs)
|
|
276
275
|
|
|
277
|
-
def __post_init__(self, **kwargs):
|
|
276
|
+
def __post_init__(self, model: nn.Module, **kwargs):
|
|
278
277
|
"""
|
|
279
278
|
Abstract method intended to be overridden by subclasses to modify or override
|
|
280
279
|
the attributes of the original model after initialization.
|
|
@@ -344,12 +343,11 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
|
|
|
344
343
|
|
|
345
344
|
def __init__(self, decoder_layer, self_attn, cross_attn):
|
|
346
345
|
super().__init__()
|
|
347
|
-
self._original_mod = decoder_layer
|
|
348
346
|
self.self_attn = self_attn
|
|
349
347
|
self.cross_attn = cross_attn
|
|
350
|
-
self.__post_init__()
|
|
348
|
+
self.__post_init__(decoder_layer)
|
|
351
349
|
|
|
352
|
-
def __post_init__(self, **kwargs):
|
|
350
|
+
def __post_init__(self, decoder_layer: nn.Module, **kwargs):
|
|
353
351
|
"""
|
|
354
352
|
Abstract method intended to be overridden by subclasses to modify or override
|
|
355
353
|
the attributes of the original model after initialization.
|
|
@@ -423,10 +421,9 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
|
|
|
423
421
|
class Seq2SeqSelfAttention(nn.Module):
|
|
424
422
|
def __init__(self, attn, **kwargs):
|
|
425
423
|
super().__init__()
|
|
426
|
-
self.
|
|
427
|
-
self.__post_init__(**kwargs)
|
|
424
|
+
self.__post_init__(attn, **kwargs)
|
|
428
425
|
|
|
429
|
-
def __post_init__(self, **kwargs):
|
|
426
|
+
def __post_init__(self, attn: nn.Module, **kwargs):
|
|
430
427
|
"""
|
|
431
428
|
Abstract method intended to be overridden by subclasses to modify or override
|
|
432
429
|
the attributes of the original model after initialization.
|
|
@@ -495,8 +492,13 @@ class Seq2SeqSelfAttention(nn.Module):
|
|
|
495
492
|
class Seq2SeqCrossAttention(nn.Module):
|
|
496
493
|
def __init__(self, attn, **kwargs):
|
|
497
494
|
super().__init__()
|
|
498
|
-
self.
|
|
499
|
-
|
|
495
|
+
self.__post_init__(attn, **kwargs)
|
|
496
|
+
|
|
497
|
+
def __post_init__(self, attn: nn.Module, **kwargs):
|
|
498
|
+
"""
|
|
499
|
+
Optional post-init hook for subclasses (e.g., to register q/k/v/out projections).
|
|
500
|
+
"""
|
|
501
|
+
pass
|
|
500
502
|
|
|
501
503
|
def forward(
|
|
502
504
|
self,
|