optimum-rbln 0.2.1a0__py3-none-any.whl → 0.2.1a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +3 -10
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +1 -10
- optimum/rbln/diffusers/modeling_diffusers.py +1 -10
- optimum/rbln/diffusers/models/__init__.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
- optimum/rbln/diffusers/models/controlnet.py +1 -10
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
- optimum/rbln/diffusers/models/unets/__init__.py +1 -10
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
- optimum/rbln/diffusers/pipelines/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
- optimum/rbln/modeling.py +1 -10
- optimum/rbln/modeling_base.py +1 -10
- optimum/rbln/modeling_config.py +1 -10
- optimum/rbln/ops/__init__.py +1 -10
- optimum/rbln/ops/attn.py +9 -18
- optimum/rbln/ops/flash_attn.py +5 -14
- optimum/rbln/ops/kv_cache_update.py +1 -10
- optimum/rbln/transformers/__init__.py +5 -12
- optimum/rbln/transformers/modeling_alias.py +1 -14
- optimum/rbln/transformers/modeling_generic.py +40 -21
- optimum/rbln/transformers/modeling_rope_utils.py +28 -0
- optimum/rbln/transformers/models/__init__.py +3 -12
- optimum/rbln/transformers/models/auto/__init__.py +1 -10
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
- optimum/rbln/transformers/models/bart/__init__.py +1 -10
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
- optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
- optimum/rbln/transformers/models/bert/__init__.py +2 -11
- optimum/rbln/transformers/models/bert/modeling_bert.py +23 -13
- optimum/rbln/transformers/models/clip/__init__.py +1 -10
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +54 -69
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
- optimum/rbln/transformers/models/dpt/__init__.py +1 -10
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
- optimum/rbln/transformers/models/exaone/__init__.py +1 -10
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
- optimum/rbln/transformers/models/gemma/__init__.py +1 -10
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
- optimum/rbln/transformers/models/llama/__init__.py +1 -10
- optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
- optimum/rbln/transformers/models/midm/__init__.py +1 -10
- optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
- optimum/rbln/transformers/models/mistral/__init__.py +1 -10
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
- optimum/rbln/transformers/models/phi/__init__.py +1 -10
- optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +16 -42
- optimum/rbln/transformers/models/t5/__init__.py +1 -10
- optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
- optimum/rbln/transformers/models/t5/t5_architecture.py +30 -16
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
- optimum/rbln/transformers/models/whisper/__init__.py +1 -10
- optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
- optimum/rbln/transformers/utils/rbln_quantization.py +6 -10
- optimum/rbln/utils/__init__.py +1 -10
- optimum/rbln/utils/decorator_utils.py +1 -10
- optimum/rbln/utils/hub.py +1 -10
- optimum/rbln/utils/import_utils.py +1 -10
- optimum/rbln/utils/logging.py +1 -10
- optimum/rbln/utils/model_utils.py +1 -10
- optimum/rbln/utils/runtime_utils.py +1 -10
- optimum/rbln/utils/save_utils.py +2 -10
- optimum/rbln/utils/submodule.py +1 -10
- {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/METADATA +6 -4
- optimum_rbln-0.2.1a2.dist-info/RECORD +114 -0
- optimum_rbln-0.2.1a2.dist-info/licenses/LICENSE +201 -0
- optimum_rbln-0.2.1a0.dist-info/RECORD +0 -114
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +0 -288
- {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import Tuple
|
25
16
|
|
26
17
|
import torch
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
26
17
|
|
@@ -67,23 +58,33 @@ class RBLNBartModel(RBLNModel):
|
|
67
58
|
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
68
59
|
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
69
60
|
|
61
|
+
signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
|
62
|
+
|
70
63
|
if rbln_model_input_names is None:
|
71
64
|
for tokenizer in preprocessors:
|
72
65
|
if hasattr(tokenizer, "model_input_names"):
|
73
|
-
rbln_model_input_names = tokenizer.model_input_names
|
66
|
+
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
74
67
|
# BartModel's forward() does not take token_type_ids as input.
|
75
68
|
# (Added because some of the tokenizers includes 'token_type_ids')
|
76
69
|
if "token_type_ids" in rbln_model_input_names:
|
77
70
|
rbln_model_input_names.remove("token_type_ids")
|
71
|
+
|
72
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
73
|
+
if invalid_params:
|
74
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
78
75
|
break
|
79
76
|
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
80
77
|
rbln_model_input_names = cls.rbln_model_input_names
|
81
78
|
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
82
|
-
input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
|
83
79
|
raise ValueError(
|
84
80
|
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
85
|
-
f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(
|
81
|
+
f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
|
86
82
|
)
|
83
|
+
else:
|
84
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
85
|
+
if invalid_params:
|
86
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
87
|
+
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
87
88
|
|
88
89
|
if rbln_batch_size is None:
|
89
90
|
rbln_batch_size = 1
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
|
-
from .modeling_bert import RBLNBertModel
|
15
|
+
from .modeling_bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
import logging
|
26
17
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
@@ -29,6 +20,7 @@ from transformers import PretrainedConfig
|
|
29
20
|
|
30
21
|
from ....modeling import RBLNModel
|
31
22
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
23
|
+
from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForQuestionAnswering
|
32
24
|
|
33
25
|
|
34
26
|
logger = logging.getLogger(__name__)
|
@@ -64,19 +56,29 @@ class RBLNBertModel(RBLNModel):
|
|
64
56
|
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
65
57
|
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
66
58
|
|
59
|
+
signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
|
60
|
+
|
67
61
|
if rbln_model_input_names is None:
|
68
62
|
for tokenizer in preprocessors:
|
69
63
|
if hasattr(tokenizer, "model_input_names"):
|
70
|
-
rbln_model_input_names = tokenizer.model_input_names
|
64
|
+
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
65
|
+
|
66
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
67
|
+
if invalid_params:
|
68
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
71
69
|
break
|
72
70
|
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
73
71
|
rbln_model_input_names = cls.rbln_model_input_names
|
74
72
|
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
75
|
-
input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
|
76
73
|
raise ValueError(
|
77
74
|
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
78
|
-
f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(
|
75
|
+
f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
|
79
76
|
)
|
77
|
+
else:
|
78
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
79
|
+
if invalid_params:
|
80
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
81
|
+
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
80
82
|
|
81
83
|
if rbln_batch_size is None:
|
82
84
|
rbln_batch_size = 1
|
@@ -96,3 +98,11 @@ class RBLNBertModel(RBLNModel):
|
|
96
98
|
|
97
99
|
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
98
100
|
return rbln_config
|
101
|
+
|
102
|
+
|
103
|
+
class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
|
104
|
+
rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
105
|
+
|
106
|
+
|
107
|
+
class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
|
108
|
+
rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import logging
|
25
16
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import math
|
25
16
|
from typing import List, Optional, Tuple
|
26
17
|
|
@@ -553,15 +544,19 @@ class DecoderOnlyAttention(nn.Module):
|
|
553
544
|
super().__init__()
|
554
545
|
self._original_mod = self_attn
|
555
546
|
self.layer_idx = self_attn.layer_idx
|
556
|
-
self.num_heads = self._original_mod
|
547
|
+
self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
|
548
|
+
self._original_mod.config, "num_attention_heads"
|
549
|
+
)
|
557
550
|
self.head_dim = self._original_mod.head_dim
|
558
551
|
self._phase = "prefill"
|
559
552
|
self.scale = torch.tensor(self.get_attn_scale())
|
560
553
|
|
561
554
|
if hasattr(self._original_mod, "num_key_value_heads"):
|
562
555
|
self.num_key_value_heads = self._original_mod.num_key_value_heads
|
556
|
+
elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
|
557
|
+
self.num_key_value_heads = self._original_mod.config.num_key_value_heads
|
563
558
|
else:
|
564
|
-
self.num_key_value_heads = self.
|
559
|
+
self.num_key_value_heads = self.num_heads
|
565
560
|
|
566
561
|
self.attention = self.get_attention()
|
567
562
|
self.__post_init__()
|
@@ -631,33 +626,21 @@ class DecoderOnlyAttention(nn.Module):
|
|
631
626
|
if batch_size > 1 and self.phase == "prefill":
|
632
627
|
raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
|
633
628
|
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
batch_position=b if self.phase == "decode" else batch_position,
|
650
|
-
seq_position=seq_position,
|
651
|
-
scale=self.scale,
|
652
|
-
)
|
653
|
-
_key_states.append(key_state)
|
654
|
-
_value_states.append(value_state)
|
655
|
-
_attn_outputs.append(attn_output)
|
656
|
-
key_states = torch.cat(_key_states, dim=0)
|
657
|
-
value_states = torch.cat(_value_states, dim=0)
|
658
|
-
attn_outputs = torch.cat(_attn_outputs, dim=0)
|
659
|
-
|
660
|
-
attn_outputs = self.o_proj(attn_outputs)
|
629
|
+
attn_output, key_state, value_state = self.attention(
|
630
|
+
query_states,
|
631
|
+
key_states,
|
632
|
+
value_states,
|
633
|
+
attention_mask,
|
634
|
+
past_key_state=past_key_values[self.layer_idx][0],
|
635
|
+
past_value_state=past_key_values[self.layer_idx][1],
|
636
|
+
batch_position=None if self.phase == "decode" else batch_position,
|
637
|
+
seq_position=seq_positions,
|
638
|
+
scale=self.scale,
|
639
|
+
)
|
640
|
+
key_states = key_state
|
641
|
+
value_states = value_state
|
642
|
+
|
643
|
+
attn_outputs = self.o_proj(attn_output)
|
661
644
|
past_key_values[self.layer_idx] = key_states, value_states
|
662
645
|
return attn_outputs, past_key_values
|
663
646
|
|
@@ -703,8 +686,13 @@ class AttentionOp(nn.Module):
|
|
703
686
|
value_state = value_state.unsqueeze(2)
|
704
687
|
attn_mask = attn_mask.unsqueeze(2)
|
705
688
|
|
689
|
+
if self.phase == "decode":
|
690
|
+
batch_size = key_state.shape[0]
|
691
|
+
else:
|
692
|
+
batch_size = 1
|
693
|
+
|
706
694
|
query_state = query_state.view(
|
707
|
-
|
695
|
+
batch_size,
|
708
696
|
self.num_key_value_heads,
|
709
697
|
self.num_heads // self.num_key_value_heads,
|
710
698
|
-1, # seq len
|
@@ -736,9 +724,9 @@ class AttentionOp(nn.Module):
|
|
736
724
|
scale,
|
737
725
|
)
|
738
726
|
|
739
|
-
attn_output = attn_output.view(
|
727
|
+
attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
|
740
728
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
741
|
-
attn_output = attn_output.reshape(
|
729
|
+
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
|
742
730
|
|
743
731
|
return attn_output, key_state.squeeze(2), value_state.squeeze(2)
|
744
732
|
|
@@ -867,31 +855,23 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
|
|
867
855
|
if cos is not None and sin is not None:
|
868
856
|
query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
|
869
857
|
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
)
|
886
|
-
_key_states.append(key_state)
|
887
|
-
_value_states.append(value_state)
|
888
|
-
_attn_outputs.append(attn_output)
|
889
|
-
key_states = torch.cat(_key_states, dim=0)
|
890
|
-
value_states = torch.cat(_value_states, dim=0)
|
891
|
-
attn_outputs = torch.cat(_attn_outputs, dim=0)
|
892
|
-
|
893
|
-
attn_outputs = self.o_proj(attn_outputs)
|
858
|
+
attn_output, key_state, value_state = self.attention(
|
859
|
+
query_states,
|
860
|
+
key_states,
|
861
|
+
value_states,
|
862
|
+
attention_mask,
|
863
|
+
past_key_state=past_key_values[self.layer_idx][0],
|
864
|
+
past_value_state=past_key_values[self.layer_idx][1],
|
865
|
+
batch_position=None if self.phase == "decode" else batch_position,
|
866
|
+
seq_position=seq_positions,
|
867
|
+
scale=self.scale,
|
868
|
+
)
|
869
|
+
key_states = key_state
|
870
|
+
value_states = value_state
|
871
|
+
|
872
|
+
attn_outputs = self.o_proj(attn_output)
|
894
873
|
past_key_values[self.layer_idx] = key_states, value_states
|
874
|
+
|
895
875
|
return attn_outputs, past_key_values
|
896
876
|
|
897
877
|
|
@@ -917,8 +897,13 @@ class FlashAttentionOp(AttentionOp):
|
|
917
897
|
value_state = value_state.unsqueeze(2)
|
918
898
|
attn_mask = attn_mask.unsqueeze(2)
|
919
899
|
|
900
|
+
if self.phase == "decode":
|
901
|
+
batch_size = key_state.shape[0]
|
902
|
+
else:
|
903
|
+
batch_size = 1
|
904
|
+
|
920
905
|
query_state = query_state.view(
|
921
|
-
|
906
|
+
batch_size,
|
922
907
|
self.num_key_value_heads,
|
923
908
|
self.num_heads // self.num_key_value_heads,
|
924
909
|
-1, # seq len
|
@@ -952,8 +937,8 @@ class FlashAttentionOp(AttentionOp):
|
|
952
937
|
)
|
953
938
|
|
954
939
|
# reshape for removing repeat_kv
|
955
|
-
attn_output = attn_output.view(
|
940
|
+
attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
|
956
941
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
957
|
-
attn_output = attn_output.reshape(
|
942
|
+
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
|
958
943
|
|
959
944
|
return attn_output, key_state, value_state
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
from dataclasses import dataclass
|
26
17
|
from pathlib import Path
|
@@ -218,14 +209,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
218
209
|
|
219
210
|
@classmethod
|
220
211
|
def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
|
212
|
+
logger.debug("Loading the LLM model to the CPU.") # TODO(jongho): Remove.
|
213
|
+
|
221
214
|
rbln_kwargs = kwargs.get("rbln_kwargs", {})
|
222
215
|
rbln_quantization = rbln_kwargs.get("quantization", None)
|
223
|
-
|
224
216
|
if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
|
225
217
|
model = cls.get_quantized_model(*args, **kwargs)
|
226
218
|
else:
|
227
219
|
model = super().get_pytorch_model(*args, **kwargs)
|
228
220
|
|
221
|
+
logger.debug("Loaded the LLM model to the CPU.")
|
229
222
|
return model
|
230
223
|
|
231
224
|
@classmethod
|
@@ -297,8 +290,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
297
290
|
rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
|
298
291
|
rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
|
299
292
|
rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
|
293
|
+
rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
|
294
|
+
|
295
|
+
if rbln_prefill_chunk_size is None:
|
296
|
+
rbln_prefill_chunk_size = 128
|
297
|
+
elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
|
298
|
+
raise ValueError(
|
299
|
+
f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
|
300
|
+
)
|
300
301
|
|
301
|
-
prefill_chunk_size = 128
|
302
302
|
if rbln_max_seq_len is None:
|
303
303
|
rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
|
304
304
|
model_config, "n_positions", None
|
@@ -369,7 +369,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
369
369
|
|
370
370
|
prefill_input_info = get_input_info(
|
371
371
|
batch_size=1,
|
372
|
-
query_length=
|
372
|
+
query_length=rbln_prefill_chunk_size,
|
373
373
|
use_inputs_embeds=rbln_use_inputs_embeds,
|
374
374
|
hidden_size=hidden_size,
|
375
375
|
)
|
@@ -393,7 +393,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
393
393
|
{
|
394
394
|
"max_seq_len": rbln_max_seq_len,
|
395
395
|
"batch_size": rbln_batch_size,
|
396
|
-
"prefill_chunk_size":
|
396
|
+
"prefill_chunk_size": rbln_prefill_chunk_size,
|
397
397
|
"use_inputs_embeds": rbln_use_inputs_embeds,
|
398
398
|
"kvcache_partition_len": rbln_kvcache_partition_len,
|
399
399
|
"attn_impl": rbln_attn_impl,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_dpt import RBLNDPTForDepthEstimation
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import logging
|
25
16
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import os
|
25
16
|
from os import environ
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import TYPE_CHECKING
|
25
16
|
|
26
17
|
import torch.nn as nn
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
|
25
16
|
from transformers import AutoModelForCausalLM
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_gemma import RBLNGemmaForCausalLM
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import TYPE_CHECKING
|
25
16
|
|
26
17
|
from ...models.decoderonly.decoderonly_architecture import (
|