optimum-rbln 0.2.1a0__py3-none-any.whl → 0.2.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +1 -10
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +1 -10
- optimum/rbln/diffusers/modeling_diffusers.py +1 -10
- optimum/rbln/diffusers/models/__init__.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
- optimum/rbln/diffusers/models/controlnet.py +1 -10
- optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
- optimum/rbln/diffusers/models/unets/__init__.py +1 -10
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
- optimum/rbln/diffusers/pipelines/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
- optimum/rbln/modeling.py +1 -10
- optimum/rbln/modeling_base.py +1 -10
- optimum/rbln/modeling_config.py +1 -10
- optimum/rbln/ops/__init__.py +1 -10
- optimum/rbln/ops/attn.py +5 -14
- optimum/rbln/ops/flash_attn.py +5 -14
- optimum/rbln/ops/kv_cache_update.py +1 -10
- optimum/rbln/transformers/__init__.py +3 -12
- optimum/rbln/transformers/modeling_alias.py +1 -14
- optimum/rbln/transformers/modeling_generic.py +40 -21
- optimum/rbln/transformers/modeling_rope_utils.py +28 -0
- optimum/rbln/transformers/models/__init__.py +3 -12
- optimum/rbln/transformers/models/auto/__init__.py +1 -10
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
- optimum/rbln/transformers/models/bart/__init__.py +1 -10
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
- optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
- optimum/rbln/transformers/models/bert/__init__.py +2 -11
- optimum/rbln/transformers/models/bert/modeling_bert.py +19 -13
- optimum/rbln/transformers/models/clip/__init__.py +1 -10
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +48 -67
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
- optimum/rbln/transformers/models/dpt/__init__.py +1 -10
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
- optimum/rbln/transformers/models/exaone/__init__.py +1 -10
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
- optimum/rbln/transformers/models/gemma/__init__.py +1 -10
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
- optimum/rbln/transformers/models/llama/__init__.py +1 -10
- optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
- optimum/rbln/transformers/models/midm/__init__.py +1 -10
- optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
- optimum/rbln/transformers/models/mistral/__init__.py +1 -10
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
- optimum/rbln/transformers/models/phi/__init__.py +1 -10
- optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -10
- optimum/rbln/transformers/models/t5/__init__.py +1 -10
- optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -10
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
- optimum/rbln/transformers/models/whisper/__init__.py +1 -10
- optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
- optimum/rbln/transformers/utils/rbln_quantization.py +6 -10
- optimum/rbln/utils/__init__.py +1 -10
- optimum/rbln/utils/decorator_utils.py +1 -10
- optimum/rbln/utils/hub.py +1 -10
- optimum/rbln/utils/import_utils.py +1 -10
- optimum/rbln/utils/logging.py +1 -10
- optimum/rbln/utils/model_utils.py +1 -10
- optimum/rbln/utils/runtime_utils.py +1 -10
- optimum/rbln/utils/save_utils.py +2 -10
- optimum/rbln/utils/submodule.py +1 -10
- {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a1.dist-info}/METADATA +6 -4
- optimum_rbln-0.2.1a1.dist-info/RECORD +114 -0
- optimum_rbln-0.2.1a1.dist-info/licenses/LICENSE +201 -0
- optimum_rbln-0.2.1a0.dist-info/RECORD +0 -114
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +0 -288
- {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a1.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
26
17
|
|
@@ -67,23 +58,33 @@ class RBLNBartModel(RBLNModel):
|
|
67
58
|
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
68
59
|
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
69
60
|
|
61
|
+
signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
|
62
|
+
|
70
63
|
if rbln_model_input_names is None:
|
71
64
|
for tokenizer in preprocessors:
|
72
65
|
if hasattr(tokenizer, "model_input_names"):
|
73
|
-
rbln_model_input_names = tokenizer.model_input_names
|
66
|
+
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
74
67
|
# BartModel's forward() does not take token_type_ids as input.
|
75
68
|
# (Added because some of the tokenizers includes 'token_type_ids')
|
76
69
|
if "token_type_ids" in rbln_model_input_names:
|
77
70
|
rbln_model_input_names.remove("token_type_ids")
|
71
|
+
|
72
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
73
|
+
if invalid_params:
|
74
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
78
75
|
break
|
79
76
|
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
80
77
|
rbln_model_input_names = cls.rbln_model_input_names
|
81
78
|
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
82
|
-
input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
|
83
79
|
raise ValueError(
|
84
80
|
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
85
|
-
f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(
|
81
|
+
f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
|
86
82
|
)
|
83
|
+
else:
|
84
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
85
|
+
if invalid_params:
|
86
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
87
|
+
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
87
88
|
|
88
89
|
if rbln_batch_size is None:
|
89
90
|
rbln_batch_size = 1
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
|
-
from .modeling_bert import RBLNBertModel
|
15
|
+
from .modeling_bert import RBLNBertForQuestionAnswering, RBLNBertModel
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
import logging
|
26
17
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
@@ -29,6 +20,7 @@ from transformers import PretrainedConfig
|
|
29
20
|
|
30
21
|
from ....modeling import RBLNModel
|
31
22
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
23
|
+
from ...modeling_generic import RBLNModelForQuestionAnswering
|
32
24
|
|
33
25
|
|
34
26
|
logger = logging.getLogger(__name__)
|
@@ -64,19 +56,29 @@ class RBLNBertModel(RBLNModel):
|
|
64
56
|
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
65
57
|
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
66
58
|
|
59
|
+
signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
|
60
|
+
|
67
61
|
if rbln_model_input_names is None:
|
68
62
|
for tokenizer in preprocessors:
|
69
63
|
if hasattr(tokenizer, "model_input_names"):
|
70
|
-
rbln_model_input_names = tokenizer.model_input_names
|
64
|
+
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
65
|
+
|
66
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
67
|
+
if invalid_params:
|
68
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
71
69
|
break
|
72
70
|
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
73
71
|
rbln_model_input_names = cls.rbln_model_input_names
|
74
72
|
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
75
|
-
input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
|
76
73
|
raise ValueError(
|
77
74
|
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
78
|
-
f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(
|
75
|
+
f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
|
79
76
|
)
|
77
|
+
else:
|
78
|
+
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
79
|
+
if invalid_params:
|
80
|
+
raise ValueError(f"Invalid model input names: {invalid_params}")
|
81
|
+
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
80
82
|
|
81
83
|
if rbln_batch_size is None:
|
82
84
|
rbln_batch_size = 1
|
@@ -96,3 +98,7 @@ class RBLNBertModel(RBLNModel):
|
|
96
98
|
|
97
99
|
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
98
100
|
return rbln_config
|
101
|
+
|
102
|
+
|
103
|
+
class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
|
104
|
+
rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import logging
|
25
16
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import math
|
25
16
|
from typing import List, Optional, Tuple
|
26
17
|
|
@@ -631,33 +622,21 @@ class DecoderOnlyAttention(nn.Module):
|
|
631
622
|
if batch_size > 1 and self.phase == "prefill":
|
632
623
|
raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
|
633
624
|
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
batch_position=b if self.phase == "decode" else batch_position,
|
650
|
-
seq_position=seq_position,
|
651
|
-
scale=self.scale,
|
652
|
-
)
|
653
|
-
_key_states.append(key_state)
|
654
|
-
_value_states.append(value_state)
|
655
|
-
_attn_outputs.append(attn_output)
|
656
|
-
key_states = torch.cat(_key_states, dim=0)
|
657
|
-
value_states = torch.cat(_value_states, dim=0)
|
658
|
-
attn_outputs = torch.cat(_attn_outputs, dim=0)
|
659
|
-
|
660
|
-
attn_outputs = self.o_proj(attn_outputs)
|
625
|
+
attn_output, key_state, value_state = self.attention(
|
626
|
+
query_states,
|
627
|
+
key_states,
|
628
|
+
value_states,
|
629
|
+
attention_mask,
|
630
|
+
past_key_state=past_key_values[self.layer_idx][0],
|
631
|
+
past_value_state=past_key_values[self.layer_idx][1],
|
632
|
+
batch_position=None if self.phase == "decode" else batch_position,
|
633
|
+
seq_position=seq_positions,
|
634
|
+
scale=self.scale,
|
635
|
+
)
|
636
|
+
key_states = key_state
|
637
|
+
value_states = value_state
|
638
|
+
|
639
|
+
attn_outputs = self.o_proj(attn_output)
|
661
640
|
past_key_values[self.layer_idx] = key_states, value_states
|
662
641
|
return attn_outputs, past_key_values
|
663
642
|
|
@@ -703,8 +682,13 @@ class AttentionOp(nn.Module):
|
|
703
682
|
value_state = value_state.unsqueeze(2)
|
704
683
|
attn_mask = attn_mask.unsqueeze(2)
|
705
684
|
|
685
|
+
if self.phase == "decode":
|
686
|
+
batch_size = key_state.shape[0]
|
687
|
+
else:
|
688
|
+
batch_size = 1
|
689
|
+
|
706
690
|
query_state = query_state.view(
|
707
|
-
|
691
|
+
batch_size,
|
708
692
|
self.num_key_value_heads,
|
709
693
|
self.num_heads // self.num_key_value_heads,
|
710
694
|
-1, # seq len
|
@@ -736,9 +720,9 @@ class AttentionOp(nn.Module):
|
|
736
720
|
scale,
|
737
721
|
)
|
738
722
|
|
739
|
-
attn_output = attn_output.view(
|
723
|
+
attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
|
740
724
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
741
|
-
attn_output = attn_output.reshape(
|
725
|
+
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
|
742
726
|
|
743
727
|
return attn_output, key_state.squeeze(2), value_state.squeeze(2)
|
744
728
|
|
@@ -867,31 +851,23 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
|
|
867
851
|
if cos is not None and sin is not None:
|
868
852
|
query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
|
869
853
|
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
)
|
886
|
-
_key_states.append(key_state)
|
887
|
-
_value_states.append(value_state)
|
888
|
-
_attn_outputs.append(attn_output)
|
889
|
-
key_states = torch.cat(_key_states, dim=0)
|
890
|
-
value_states = torch.cat(_value_states, dim=0)
|
891
|
-
attn_outputs = torch.cat(_attn_outputs, dim=0)
|
892
|
-
|
893
|
-
attn_outputs = self.o_proj(attn_outputs)
|
854
|
+
attn_output, key_state, value_state = self.attention(
|
855
|
+
query_states,
|
856
|
+
key_states,
|
857
|
+
value_states,
|
858
|
+
attention_mask,
|
859
|
+
past_key_state=past_key_values[self.layer_idx][0],
|
860
|
+
past_value_state=past_key_values[self.layer_idx][1],
|
861
|
+
batch_position=None if self.phase == "decode" else batch_position,
|
862
|
+
seq_position=seq_positions,
|
863
|
+
scale=self.scale,
|
864
|
+
)
|
865
|
+
key_states = key_state
|
866
|
+
value_states = value_state
|
867
|
+
|
868
|
+
attn_outputs = self.o_proj(attn_output)
|
894
869
|
past_key_values[self.layer_idx] = key_states, value_states
|
870
|
+
|
895
871
|
return attn_outputs, past_key_values
|
896
872
|
|
897
873
|
|
@@ -917,8 +893,13 @@ class FlashAttentionOp(AttentionOp):
|
|
917
893
|
value_state = value_state.unsqueeze(2)
|
918
894
|
attn_mask = attn_mask.unsqueeze(2)
|
919
895
|
|
896
|
+
if self.phase == "decode":
|
897
|
+
batch_size = key_state.shape[0]
|
898
|
+
else:
|
899
|
+
batch_size = 1
|
900
|
+
|
920
901
|
query_state = query_state.view(
|
921
|
-
|
902
|
+
batch_size,
|
922
903
|
self.num_key_value_heads,
|
923
904
|
self.num_heads // self.num_key_value_heads,
|
924
905
|
-1, # seq len
|
@@ -952,8 +933,8 @@ class FlashAttentionOp(AttentionOp):
|
|
952
933
|
)
|
953
934
|
|
954
935
|
# reshape for removing repeat_kv
|
955
|
-
attn_output = attn_output.view(
|
936
|
+
attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
|
956
937
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
957
|
-
attn_output = attn_output.reshape(
|
938
|
+
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
|
958
939
|
|
959
940
|
return attn_output, key_state, value_state
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import inspect
|
25
16
|
from dataclasses import dataclass
|
26
17
|
from pathlib import Path
|
@@ -218,14 +209,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
218
209
|
|
219
210
|
@classmethod
|
220
211
|
def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
|
212
|
+
logger.debug("Loading the LLM model to the CPU.") # TODO(jongho): Remove.
|
213
|
+
|
221
214
|
rbln_kwargs = kwargs.get("rbln_kwargs", {})
|
222
215
|
rbln_quantization = rbln_kwargs.get("quantization", None)
|
223
|
-
|
224
216
|
if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
|
225
217
|
model = cls.get_quantized_model(*args, **kwargs)
|
226
218
|
else:
|
227
219
|
model = super().get_pytorch_model(*args, **kwargs)
|
228
220
|
|
221
|
+
logger.debug("Loaded the LLM model to the CPU.")
|
229
222
|
return model
|
230
223
|
|
231
224
|
@classmethod
|
@@ -297,8 +290,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
297
290
|
rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
|
298
291
|
rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
|
299
292
|
rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))
|
293
|
+
rbln_prefill_chunk_size = rbln_kwargs.get("prefill_chunk_size", None)
|
294
|
+
|
295
|
+
if rbln_prefill_chunk_size is None:
|
296
|
+
rbln_prefill_chunk_size = 128
|
297
|
+
elif rbln_prefill_chunk_size % 64 != 0 or rbln_prefill_chunk_size == 0:
|
298
|
+
raise ValueError(
|
299
|
+
f"Invalid rbln_prefill_chunk_size: {rbln_prefill_chunk_size}. It must be a nonzero multiple of 64."
|
300
|
+
)
|
300
301
|
|
301
|
-
prefill_chunk_size = 128
|
302
302
|
if rbln_max_seq_len is None:
|
303
303
|
rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
|
304
304
|
model_config, "n_positions", None
|
@@ -369,7 +369,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
369
369
|
|
370
370
|
prefill_input_info = get_input_info(
|
371
371
|
batch_size=1,
|
372
|
-
query_length=
|
372
|
+
query_length=rbln_prefill_chunk_size,
|
373
373
|
use_inputs_embeds=rbln_use_inputs_embeds,
|
374
374
|
hidden_size=hidden_size,
|
375
375
|
)
|
@@ -393,7 +393,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
393
393
|
{
|
394
394
|
"max_seq_len": rbln_max_seq_len,
|
395
395
|
"batch_size": rbln_batch_size,
|
396
|
-
"prefill_chunk_size":
|
396
|
+
"prefill_chunk_size": rbln_prefill_chunk_size,
|
397
397
|
"use_inputs_embeds": rbln_use_inputs_embeds,
|
398
398
|
"kvcache_partition_len": rbln_kvcache_partition_len,
|
399
399
|
"attn_impl": rbln_attn_impl,
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_dpt import RBLNDPTForDepthEstimation
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import logging
|
25
16
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
import os
|
25
16
|
from os import environ
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import TYPE_CHECKING
|
25
16
|
|
26
17
|
import torch.nn as nn
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
|
25
16
|
from transformers import AutoModelForCausalLM
|
26
17
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_gemma import RBLNGemmaForCausalLM
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from typing import TYPE_CHECKING
|
25
16
|
|
26
17
|
from ...models.decoderonly.decoderonly_architecture import (
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,15 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from ....utils import logging
|
25
16
|
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
26
17
|
from .gemma_architecture import GemmaWrapper
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
2
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,13 +12,4 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
# Portions of this software are licensed under the Apache License,
|
16
|
-
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
-
# additional information regarding copyright ownership.
|
18
|
-
|
19
|
-
# All other portions of this software, including proprietary code,
|
20
|
-
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
-
# copied, modified, or distributed without prior written permission
|
22
|
-
# from Rebellions Inc.
|
23
|
-
|
24
15
|
from .modeling_gpt2 import RBLNGPT2LMHeadModel
|