optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +156 -36
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +85 -75
- optimum/rbln/transformers/__init__.py +79 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +96 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
@@ -32,29 +32,60 @@ _import_structure = {
|
|
32
32
|
"RBLNAutoModelForSpeechSeq2Seq",
|
33
33
|
"RBLNAutoModelForVision2Seq",
|
34
34
|
],
|
35
|
-
"bart": [
|
36
|
-
|
35
|
+
"bart": [
|
36
|
+
"RBLNBartForConditionalGeneration",
|
37
|
+
"RBLNBartModel",
|
38
|
+
"RBLNBartForConditionalGenerationConfig",
|
39
|
+
"RBLNBartModelConfig",
|
40
|
+
],
|
41
|
+
"bert": [
|
42
|
+
"RBLNBertModel",
|
43
|
+
"RBLNBertModelConfig",
|
44
|
+
"RBLNBertForQuestionAnswering",
|
45
|
+
"RBLNBertForQuestionAnsweringConfig",
|
46
|
+
"RBLNBertForMaskedLM",
|
47
|
+
"RBLNBertForMaskedLMConfig",
|
48
|
+
],
|
37
49
|
"clip": [
|
38
50
|
"RBLNCLIPTextModel",
|
51
|
+
"RBLNCLIPTextModelConfig",
|
39
52
|
"RBLNCLIPTextModelWithProjection",
|
53
|
+
"RBLNCLIPTextModelWithProjectionConfig",
|
40
54
|
"RBLNCLIPVisionModel",
|
55
|
+
"RBLNCLIPVisionModelConfig",
|
41
56
|
"RBLNCLIPVisionModelWithProjection",
|
57
|
+
"RBLNCLIPVisionModelWithProjectionConfig",
|
58
|
+
],
|
59
|
+
"decoderonly": [
|
60
|
+
"RBLNDecoderOnlyModelForCausalLM",
|
61
|
+
"RBLNDecoderOnlyModelForCausalLMConfig",
|
62
|
+
],
|
63
|
+
"dpt": [
|
64
|
+
"RBLNDPTForDepthEstimation",
|
65
|
+
"RBLNDPTForDepthEstimationConfig",
|
66
|
+
],
|
67
|
+
"exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
|
68
|
+
"gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
|
69
|
+
"gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
|
70
|
+
"llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig"],
|
71
|
+
"llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
|
72
|
+
"midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
|
73
|
+
"mistral": ["RBLNMistralForCausalLM", "RBLNMistralForCausalLMConfig"],
|
74
|
+
"phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig"],
|
75
|
+
"qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig"],
|
76
|
+
"time_series_transformers": [
|
77
|
+
"RBLNTimeSeriesTransformerForPrediction",
|
78
|
+
"RBLNTimeSeriesTransformerForPredictionConfig",
|
79
|
+
],
|
80
|
+
"t5": [
|
81
|
+
"RBLNT5EncoderModel",
|
82
|
+
"RBLNT5ForConditionalGeneration",
|
83
|
+
"RBLNT5EncoderModelConfig",
|
84
|
+
"RBLNT5ForConditionalGenerationConfig",
|
42
85
|
],
|
43
|
-
"
|
44
|
-
"
|
45
|
-
"
|
46
|
-
"gpt2": ["RBLNGPT2LMHeadModel"],
|
47
|
-
"llama": ["RBLNLlamaForCausalLM"],
|
48
|
-
"llava_next": ["RBLNLlavaNextForConditionalGeneration"],
|
49
|
-
"midm": ["RBLNMidmLMHeadModel"],
|
50
|
-
"mistral": ["RBLNMistralForCausalLM"],
|
51
|
-
"phi": ["RBLNPhiForCausalLM"],
|
52
|
-
"qwen2": ["RBLNQwen2ForCausalLM"],
|
53
|
-
"time_series_transformers": ["RBLNTimeSeriesTransformerForPrediction"],
|
54
|
-
"t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
|
55
|
-
"wav2vec2": ["RBLNWav2Vec2ForCTC"],
|
56
|
-
"whisper": ["RBLNWhisperForConditionalGeneration"],
|
57
|
-
"xlm_roberta": ["RBLNXLMRobertaModel"],
|
86
|
+
"wav2vec2": ["RBLNWav2Vec2ForCTC", "RBLNWav2Vec2ForCTCConfig"],
|
87
|
+
"whisper": ["RBLNWhisperForConditionalGeneration", "RBLNWhisperForConditionalGenerationConfig"],
|
88
|
+
"xlm_roberta": ["RBLNXLMRobertaModel", "RBLNXLMRobertaModelConfig"],
|
58
89
|
}
|
59
90
|
|
60
91
|
if TYPE_CHECKING:
|
@@ -72,29 +103,60 @@ if TYPE_CHECKING:
|
|
72
103
|
RBLNAutoModelForSpeechSeq2Seq,
|
73
104
|
RBLNAutoModelForVision2Seq,
|
74
105
|
)
|
75
|
-
from .bart import
|
76
|
-
|
106
|
+
from .bart import (
|
107
|
+
RBLNBartForConditionalGeneration,
|
108
|
+
RBLNBartForConditionalGenerationConfig,
|
109
|
+
RBLNBartModel,
|
110
|
+
RBLNBartModelConfig,
|
111
|
+
)
|
112
|
+
from .bert import (
|
113
|
+
RBLNBertForMaskedLM,
|
114
|
+
RBLNBertForMaskedLMConfig,
|
115
|
+
RBLNBertForQuestionAnswering,
|
116
|
+
RBLNBertForQuestionAnsweringConfig,
|
117
|
+
RBLNBertModel,
|
118
|
+
RBLNBertModelConfig,
|
119
|
+
)
|
77
120
|
from .clip import (
|
78
121
|
RBLNCLIPTextModel,
|
122
|
+
RBLNCLIPTextModelConfig,
|
79
123
|
RBLNCLIPTextModelWithProjection,
|
124
|
+
RBLNCLIPTextModelWithProjectionConfig,
|
80
125
|
RBLNCLIPVisionModel,
|
126
|
+
RBLNCLIPVisionModelConfig,
|
81
127
|
RBLNCLIPVisionModelWithProjection,
|
128
|
+
RBLNCLIPVisionModelWithProjectionConfig,
|
129
|
+
)
|
130
|
+
from .decoderonly import (
|
131
|
+
RBLNDecoderOnlyModelForCausalLM,
|
132
|
+
RBLNDecoderOnlyModelForCausalLMConfig,
|
133
|
+
)
|
134
|
+
from .dpt import (
|
135
|
+
RBLNDPTForDepthEstimation,
|
136
|
+
RBLNDPTForDepthEstimationConfig,
|
137
|
+
)
|
138
|
+
from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
|
139
|
+
from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
|
140
|
+
from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
|
141
|
+
from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
|
142
|
+
from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
|
143
|
+
from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
|
144
|
+
from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
|
145
|
+
from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig
|
146
|
+
from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig
|
147
|
+
from .t5 import (
|
148
|
+
RBLNT5EncoderModel,
|
149
|
+
RBLNT5EncoderModelConfig,
|
150
|
+
RBLNT5ForConditionalGeneration,
|
151
|
+
RBLNT5ForConditionalGenerationConfig,
|
152
|
+
)
|
153
|
+
from .time_series_transformers import (
|
154
|
+
RBLNTimeSeriesTransformerForPrediction,
|
155
|
+
RBLNTimeSeriesTransformerForPredictionConfig,
|
82
156
|
)
|
83
|
-
from .
|
84
|
-
from .
|
85
|
-
from .
|
86
|
-
from .gpt2 import RBLNGPT2LMHeadModel
|
87
|
-
from .llama import RBLNLlamaForCausalLM
|
88
|
-
from .llava_next import RBLNLlavaNextForConditionalGeneration
|
89
|
-
from .midm import RBLNMidmLMHeadModel
|
90
|
-
from .mistral import RBLNMistralForCausalLM
|
91
|
-
from .phi import RBLNPhiForCausalLM
|
92
|
-
from .qwen2 import RBLNQwen2ForCausalLM
|
93
|
-
from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
|
94
|
-
from .time_series_transformers import RBLNTimeSeriesTransformerForPrediction
|
95
|
-
from .wav2vec2 import RBLNWav2Vec2ForCTC
|
96
|
-
from .whisper import RBLNWhisperForConditionalGeneration
|
97
|
-
from .xlm_roberta import RBLNXLMRobertaModel
|
157
|
+
from .wav2vec2 import RBLNWav2Vec2ForCTC, RBLNWav2Vec2ForCTCConfig
|
158
|
+
from .whisper import RBLNWhisperForConditionalGeneration, RBLNWhisperForConditionalGenerationConfig
|
159
|
+
from .xlm_roberta import RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig
|
98
160
|
|
99
161
|
else:
|
100
162
|
import sys
|
@@ -20,8 +20,8 @@ from transformers import AutoConfig, PretrainedConfig
|
|
20
20
|
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
21
21
|
from transformers.models.auto.auto_factory import _get_model_class
|
22
22
|
|
23
|
+
from optimum.rbln.configuration_utils import RBLNAutoConfig
|
23
24
|
from optimum.rbln.modeling_base import RBLNBaseModel
|
24
|
-
from optimum.rbln.modeling_config import RBLNConfig
|
25
25
|
from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
|
26
26
|
|
27
27
|
|
@@ -154,9 +154,9 @@ class _BaseAutoModelClass:
|
|
154
154
|
model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
|
155
155
|
model_id=pretrained_model_name_or_path, **filtered_kwargs
|
156
156
|
)
|
157
|
-
rbln_config =
|
157
|
+
rbln_config = RBLNAutoConfig.load(model_path_subfolder)
|
158
158
|
|
159
|
-
return rbln_config.
|
159
|
+
return rbln_config.rbln_model_cls_name
|
160
160
|
|
161
161
|
@classmethod
|
162
162
|
def from_pretrained(
|
@@ -13,4 +13,5 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from ....ops import paged_attn_decode, paged_causal_attn_decode
|
16
|
+
from .configuration_bart import RBLNBartForConditionalGenerationConfig, RBLNBartModelConfig
|
16
17
|
from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
|
@@ -0,0 +1,24 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
|
16
|
+
from ..seq2seq import RBLNModelForSeq2SeqLMConfig
|
17
|
+
|
18
|
+
|
19
|
+
class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
|
20
|
+
pass
|
21
|
+
|
22
|
+
|
23
|
+
class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
|
24
|
+
pass
|
@@ -13,110 +13,36 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import inspect
|
16
|
-
from typing import TYPE_CHECKING, Any, Callable
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable
|
17
17
|
|
18
|
-
from transformers import BartForConditionalGeneration,
|
18
|
+
from transformers import BartForConditionalGeneration, PreTrainedModel
|
19
19
|
|
20
|
-
from ....modeling import RBLNModel
|
21
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
22
20
|
from ....utils.logging import get_logger
|
21
|
+
from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
|
23
22
|
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
24
23
|
from .bart_architecture import BartWrapper
|
24
|
+
from .configuration_bart import RBLNBartForConditionalGenerationConfig
|
25
25
|
|
26
26
|
|
27
27
|
logger = get_logger()
|
28
28
|
|
29
29
|
|
30
30
|
if TYPE_CHECKING:
|
31
|
-
from transformers import
|
31
|
+
from transformers import PreTrainedModel
|
32
32
|
|
33
33
|
|
34
|
-
class RBLNBartModel(
|
35
|
-
|
36
|
-
def _get_rbln_config(
|
37
|
-
cls,
|
38
|
-
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
39
|
-
model_config: Optional["PretrainedConfig"] = None,
|
40
|
-
rbln_kwargs: Dict[str, Any] = {},
|
41
|
-
) -> RBLNConfig:
|
42
|
-
rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
|
43
|
-
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
44
|
-
rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
|
45
|
-
|
46
|
-
max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
|
47
|
-
|
48
|
-
if rbln_max_seq_len is None:
|
49
|
-
rbln_max_seq_len = max_position_embeddings
|
50
|
-
if rbln_max_seq_len is None:
|
51
|
-
for tokenizer in preprocessors:
|
52
|
-
if hasattr(tokenizer, "model_max_length"):
|
53
|
-
rbln_max_seq_len = tokenizer.model_max_length
|
54
|
-
break
|
55
|
-
if rbln_max_seq_len is None:
|
56
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
57
|
-
|
58
|
-
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
59
|
-
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
60
|
-
|
61
|
-
signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
|
62
|
-
|
63
|
-
if rbln_model_input_names is None:
|
64
|
-
for tokenizer in preprocessors:
|
65
|
-
if hasattr(tokenizer, "model_input_names"):
|
66
|
-
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
67
|
-
# BartModel's forward() does not take token_type_ids as input.
|
68
|
-
# (Added because some of the tokenizers includes 'token_type_ids')
|
69
|
-
if "token_type_ids" in rbln_model_input_names:
|
70
|
-
rbln_model_input_names.remove("token_type_ids")
|
71
|
-
|
72
|
-
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
73
|
-
if invalid_params:
|
74
|
-
raise ValueError(f"Invalid model input names: {invalid_params}")
|
75
|
-
break
|
76
|
-
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
77
|
-
rbln_model_input_names = cls.rbln_model_input_names
|
78
|
-
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
79
|
-
raise ValueError(
|
80
|
-
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
81
|
-
f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
|
82
|
-
)
|
83
|
-
else:
|
84
|
-
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
85
|
-
if invalid_params:
|
86
|
-
raise ValueError(f"Invalid model input names: {invalid_params}")
|
87
|
-
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
88
|
-
|
89
|
-
if rbln_batch_size is None:
|
90
|
-
rbln_batch_size = 1
|
91
|
-
|
92
|
-
input_info = [
|
93
|
-
(model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
|
94
|
-
for model_input_name in rbln_model_input_names
|
95
|
-
]
|
96
|
-
|
97
|
-
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
98
|
-
|
99
|
-
rbln_config = RBLNConfig(
|
100
|
-
rbln_cls=cls.__name__,
|
101
|
-
compile_cfgs=[rbln_compile_config],
|
102
|
-
rbln_kwargs=rbln_kwargs,
|
103
|
-
)
|
104
|
-
|
105
|
-
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
106
|
-
return rbln_config
|
34
|
+
class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
|
35
|
+
pass
|
107
36
|
|
108
37
|
|
109
38
|
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
|
110
39
|
support_causal_attn = True
|
111
40
|
|
112
41
|
@classmethod
|
113
|
-
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config:
|
114
|
-
|
115
|
-
rbln_config.
|
42
|
+
def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNBartForConditionalGenerationConfig):
|
43
|
+
return BartWrapper(
|
44
|
+
model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
|
116
45
|
)
|
117
|
-
use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
|
118
|
-
|
119
|
-
return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)
|
120
46
|
|
121
47
|
def __getattr__(self, __name: str) -> Any:
|
122
48
|
def redirect(func):
|
@@ -12,4 +12,5 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from .configuration_bert import RBLNBertForMaskedLMConfig, RBLNBertForQuestionAnsweringConfig, RBLNBertModelConfig
|
15
16
|
from .modeling_bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from ...configuration_generic import (
|
16
|
+
RBLNModelForMaskedLMConfig,
|
17
|
+
RBLNModelForQuestionAnsweringConfig,
|
18
|
+
RBLNTransformerEncoderForFeatureExtractionConfig,
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
|
23
|
+
pass
|
24
|
+
|
25
|
+
|
26
|
+
class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
|
27
|
+
pass
|
28
|
+
|
29
|
+
|
30
|
+
class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
|
31
|
+
pass
|
@@ -12,92 +12,19 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
import inspect
|
16
|
-
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
17
|
-
|
18
|
-
from transformers import PretrainedConfig
|
19
|
-
|
20
|
-
from ....modeling import RBLNModel
|
21
|
-
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
22
15
|
from ....utils.logging import get_logger
|
23
|
-
from ...modeling_generic import
|
16
|
+
from ...modeling_generic import (
|
17
|
+
RBLNModelForMaskedLM,
|
18
|
+
RBLNModelForQuestionAnswering,
|
19
|
+
RBLNTransformerEncoderForFeatureExtraction,
|
20
|
+
)
|
24
21
|
|
25
22
|
|
26
23
|
logger = get_logger(__name__)
|
27
24
|
|
28
|
-
if TYPE_CHECKING:
|
29
|
-
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
30
|
-
|
31
|
-
|
32
|
-
class RBLNBertModel(RBLNModel):
|
33
|
-
@classmethod
|
34
|
-
def _get_rbln_config(
|
35
|
-
cls,
|
36
|
-
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
37
|
-
model_config: Optional["PretrainedConfig"] = None,
|
38
|
-
rbln_kwargs: Dict[str, Any] = {},
|
39
|
-
) -> RBLNConfig:
|
40
|
-
rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
|
41
|
-
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
42
|
-
rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
|
43
|
-
|
44
|
-
max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
|
45
|
-
|
46
|
-
if rbln_max_seq_len is None:
|
47
|
-
rbln_max_seq_len = max_position_embeddings
|
48
|
-
if rbln_max_seq_len is None:
|
49
|
-
for tokenizer in preprocessors:
|
50
|
-
if hasattr(tokenizer, "model_max_length"):
|
51
|
-
rbln_max_seq_len = tokenizer.model_max_length
|
52
|
-
break
|
53
|
-
if rbln_max_seq_len is None:
|
54
|
-
raise ValueError("`rbln_max_seq_len` should be specified!")
|
55
|
-
|
56
|
-
if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
|
57
|
-
raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
|
58
|
-
|
59
|
-
signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
|
60
|
-
|
61
|
-
if rbln_model_input_names is None:
|
62
|
-
for tokenizer in preprocessors:
|
63
|
-
if hasattr(tokenizer, "model_input_names"):
|
64
|
-
rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
|
65
|
-
|
66
|
-
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
67
|
-
if invalid_params:
|
68
|
-
raise ValueError(f"Invalid model input names: {invalid_params}")
|
69
|
-
break
|
70
|
-
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
71
|
-
rbln_model_input_names = cls.rbln_model_input_names
|
72
|
-
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
73
|
-
raise ValueError(
|
74
|
-
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
75
|
-
f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
|
76
|
-
)
|
77
|
-
else:
|
78
|
-
invalid_params = set(rbln_model_input_names) - set(signature_params)
|
79
|
-
if invalid_params:
|
80
|
-
raise ValueError(f"Invalid model input names: {invalid_params}")
|
81
|
-
rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
|
82
|
-
|
83
|
-
if rbln_batch_size is None:
|
84
|
-
rbln_batch_size = 1
|
85
|
-
|
86
|
-
input_info = [
|
87
|
-
(model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
|
88
|
-
for model_input_name in rbln_model_input_names
|
89
|
-
]
|
90
|
-
|
91
|
-
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
92
|
-
|
93
|
-
rbln_config = RBLNConfig(
|
94
|
-
rbln_cls=cls.__name__,
|
95
|
-
compile_cfgs=[rbln_compile_config],
|
96
|
-
rbln_kwargs=rbln_kwargs,
|
97
|
-
)
|
98
25
|
|
99
|
-
|
100
|
-
|
26
|
+
class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
|
27
|
+
rbln_model_input_names = ["input_ids", "attention_mask"]
|
101
28
|
|
102
29
|
|
103
30
|
class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
|
@@ -12,6 +12,12 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from .configuration_clip import (
|
16
|
+
RBLNCLIPTextModelConfig,
|
17
|
+
RBLNCLIPTextModelWithProjectionConfig,
|
18
|
+
RBLNCLIPVisionModelConfig,
|
19
|
+
RBLNCLIPVisionModelWithProjectionConfig,
|
20
|
+
)
|
15
21
|
from .modeling_clip import (
|
16
22
|
RBLNCLIPTextModel,
|
17
23
|
RBLNCLIPTextModelWithProjection,
|
@@ -0,0 +1,79 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
18
|
+
|
19
|
+
|
20
|
+
class RBLNCLIPTextModelConfig(RBLNModelConfig):
|
21
|
+
def __init__(self, batch_size: Optional[int] = None, **kwargs):
|
22
|
+
"""
|
23
|
+
Args:
|
24
|
+
batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
|
25
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
26
|
+
|
27
|
+
Raises:
|
28
|
+
ValueError: If batch_size is not an integer or is negative (None or 0 falls back to 1).
|
29
|
+
"""
|
30
|
+
super().__init__(**kwargs)
|
31
|
+
self.batch_size = batch_size or 1
|
32
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
33
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
34
|
+
|
35
|
+
|
36
|
+
class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
|
37
|
+
pass
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNCLIPVisionModelConfig(RBLNModelConfig):
|
41
|
+
def __init__(self, batch_size: Optional[int] = None, image_size: Optional[int] = None, **kwargs):
|
42
|
+
"""
|
43
|
+
Args:
|
44
|
+
batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
|
45
|
+
image_size (Optional[Union[int, Tuple[int, int], List[int], Dict[str, int]]]): The size of input images. Can be an integer for square images,
|
46
|
+
a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
|
47
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
48
|
+
|
49
|
+
Raises:
|
50
|
+
ValueError: If batch_size is not an integer or is negative (None or 0 falls back to 1).
|
51
|
+
"""
|
52
|
+
super().__init__(**kwargs)
|
53
|
+
self.batch_size = batch_size or 1
|
54
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
55
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
56
|
+
|
57
|
+
self.image_size = image_size
|
58
|
+
|
59
|
+
@property
|
60
|
+
def image_width(self):
|
61
|
+
if isinstance(self.image_size, int):
|
62
|
+
return self.image_size
|
63
|
+
elif isinstance(self.image_size, (list, tuple)):
|
64
|
+
return self.image_size[1]
|
65
|
+
else:
|
66
|
+
return self.image_size["width"]
|
67
|
+
|
68
|
+
@property
|
69
|
+
def image_height(self):
|
70
|
+
if isinstance(self.image_size, int):
|
71
|
+
return self.image_size
|
72
|
+
elif isinstance(self.image_size, (list, tuple)):
|
73
|
+
return self.image_size[0]
|
74
|
+
else:
|
75
|
+
return self.image_size["height"]
|
76
|
+
|
77
|
+
|
78
|
+
class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
|
79
|
+
pass
|