optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/__init__.py

@@ -21,35 +21,84 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from typing import TYPE_CHECKING
 
-from .
[old lines 26-55 removed; their content is truncated in the source rendering]
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "auto": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+    ],
+    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
+    "bert": ["RBLNBertModel"],
+    "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+    "dpt": ["RBLNDPTForDepthEstimation"],
+    "exaone": ["RBLNExaoneForCausalLM"],
+    "gemma": ["RBLNGemmaForCausalLM"],
+    "gpt2": ["RBLNGPT2LMHeadModel"],
+    "llama": ["RBLNLlamaForCausalLM"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
+    "midm": ["RBLNMidmLMHeadModel"],
+    "mistral": ["RBLNMistralForCausalLM"],
+    "phi": ["RBLNPhiForCausalLM"],
+    "qwen2": ["RBLNQwen2ForCausalLM"],
+    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
+    "whisper": ["RBLNWhisperForConditionalGeneration"],
+    "xlm_roberta": ["RBLNXLMRobertaModel"],
+}
+
+if TYPE_CHECKING:
+    from .auto import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+    )
+    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
+    from .bert import RBLNBertModel
+    from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+    from .dpt import RBLNDPTForDepthEstimation
+    from .exaone import RBLNExaoneForCausalLM
+    from .gemma import RBLNGemmaForCausalLM
+    from .gpt2 import RBLNGPT2LMHeadModel
+    from .llama import RBLNLlamaForCausalLM
+    from .llava_next import RBLNLlavaNextForConditionalGeneration
+    from .midm import RBLNMidmLMHeadModel
+    from .mistral import RBLNMistralForCausalLM
+    from .phi import RBLNPhiForCausalLM
+    from .qwen2 import RBLNQwen2ForCausalLM
+    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+    from .wav2vec2 import RBLNWav2Vec2ForCTC
+    from .whisper import RBLNWhisperForConditionalGeneration
+    from .xlm_roberta import RBLNXLMRobertaModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
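
The rewritten `__init__.py` swaps the old eager import block for transformers' `_LazyModule`, so importing the package binds no model code until one of the names in `_import_structure` is first touched. Below is a minimal sketch of the same deferred-import mechanism using a module-level `__getattr__` (PEP 562); the package layout and names are illustrative, not the library's actual code:

# hypothetical mypackage/models/__init__.py -- same effect as _LazyModule above
import importlib

_import_structure = {
    "llama": ["RBLNLlamaForCausalLM"],
    "gpt2": ["RBLNGPT2LMHeadModel"],
}

# Reverse map: exported name -> submodule that defines it.
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}


def __getattr__(name):
    # PEP 562: called only when `name` is not already a module attribute,
    # so each submodule is imported at most once, on first access.
    if name in _name_to_module:
        submodule = importlib.import_module(f".{_name_to_module[name]}", __name__)
        value = getattr(submodule, name)
        globals()[name] = value  # cache so later lookups bypass __getattr__
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

`_LazyModule` does the same job more completely (it also installs itself in `sys.modules` and implements `__dir__`), but the access-triggers-import behavior is the same in spirit.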

optimum/rbln/transformers/models/auto/auto_factory.py

@@ -22,8 +22,16 @@
 # from Rebellions Inc.
 
 import importlib
+import inspect
+import warnings
 
-from transformers import AutoConfig
+from transformers import AutoConfig, PretrainedConfig
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.models.auto.auto_factory import _get_model_class
+
+from optimum.rbln.modeling_base import RBLNBaseModel
+from optimum.rbln.modeling_config import RBLNConfig
+from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
 
 
 class _BaseAutoModelClass:

@@ -33,46 +41,132 @@ class _BaseAutoModelClass:
     def __init__(self, *args, **kwargs):
         raise EnvironmentError(
             f"{self.__class__.__name__} is designed to be instantiated "
-            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`
-            f"`{self.__class__.__name__}.from_config(config)` methods."
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`"
         )
 
     @classmethod
     def get_rbln_cls(
         cls,
-
+        pretrained_model_name_or_path,
         *args,
+        export=True,
         **kwargs,
     ):
[old lines 47-64 removed; their content is truncated in the source rendering]
+        """
+        Determine the appropriate RBLN model class based on the given model ID and configuration.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
+            export (bool): Whether to infer the class based on Hugging Face (HF) architecture.
+            kwargs: Additional arguments for configuration and loading.
+
+        Returns:
+            RBLNBaseModel: The corresponding RBLN model class.
+        """
+        if export:
+            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
+            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
+        else:
+            rbln_class_name = cls.get_rbln_model_class_name(pretrained_model_name_or_path, **kwargs)
+
+        if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
+            raise ValueError(
+                f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
+                "Please use the `from_pretrained()` method of the appropriate class to load this model, "
+                f"or directly use '{rbln_class_name}.from_pretrained()`."
+            )
 
         try:
+            module = importlib.import_module("optimum.rbln")
             rbln_cls = getattr(module, rbln_class_name)
         except AttributeError as e:
             raise AttributeError(
-                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{
+                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
                 "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
             ) from e
 
         return rbln_cls
 
+    @classmethod
+    def infer_hf_model_class(
+        cls,
+        pretrained_model_name_or_path,
+        *args,
+        **kwargs,
+    ):
+        """
+        Infer the Hugging Face model class based on the configuration or model name.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
+            kwargs: Additional arguments for configuration and loading.
+
+        Returns:
+            PretrainedModel: The inferred Hugging Face model class.
+        """
+
+        # Try to load configuration if provided or retrieve it from the model ID
+        config = kwargs.pop("config", None)
+        kwargs.update({"trust_remote_code": True})
+        kwargs["_from_auto"] = True
+
+        # Load configuration if not already provided
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path,
+                return_unused_kwargs=True,
+                **kwargs,
+            )
+
+        # Get hf_model_class from Config
+        has_remote_code = (
+            hasattr(config, "auto_map") and convert_rbln_to_hf_model_name(cls.__name__) in config.auto_map
+        )
+        if has_remote_code:
+            class_ref = config.auto_map[convert_rbln_to_hf_model_name(cls.__name__)]
+            model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+        elif type(config) in cls._model_mapping.keys():
+            model_class = _get_model_class(config, cls._model_mapping)
+        else:
+            raise ValueError(
+                f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+                f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
+            )
+
+        if model_class.__name__ != config.architectures[0]:
+            warnings.warn(
+                f"`{cls.__name__}.from_pretrained()` is invoking `{convert_hf_to_rbln_model_name(model_class.__name__)}.from_pretrained()`, which does not match the "
+                f"expected architecture `RBLN{config.architectures[0]}` from config. This mismatch could cause some operations to not be properly loaded "
+                f"from the checkpoint, leading to potential unintended behavior. If this is not intentional, consider calling the "
+                f"`from_pretrained()` method directly from the `RBLN{config.architectures[0]}` class instead.",
+                UserWarning,
+            )
+
+        return model_class
+
+    @classmethod
+    def get_rbln_model_class_name(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Retrieve the path to the compiled model directory for a given RBLN model.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier of the model.
+            kwargs: Additional arguments that match the parameters of `_load_compiled_model_dir`.
+
+        Returns:
+            str: Path to the compiled model directory.
+        """
+        sig = inspect.signature(RBLNBaseModel._load_compiled_model_dir)
+        valid_params = sig.parameters.keys()
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+
+        model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
+            model_id=pretrained_model_name_or_path, **filtered_kwargs
+        )
+        rbln_config = RBLNConfig.load(model_path_subfolder)
+
+        return rbln_config.meta["cls"]
+
     @classmethod
     def from_pretrained(
         cls,
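
Taken together, `get_rbln_cls`, `infer_hf_model_class`, and `get_rbln_model_class_name` give the auto classes two resolution paths: with `export=True` the HF architecture is read from the checkpoint's config (including `trust_remote_code` repos via `auto_map`) and mapped to its `RBLN*` counterpart; with `export=False` the concrete class name is read back from the compiled artifact's `rbln_config.meta["cls"]`. A usage sketch, assuming `from_pretrained` forwards `export` to `get_rbln_cls` as above (the model ID, directory, and `rbln_*` kwargs are illustrative, and actually running this requires an RBLN NPU and compiler):

from optimum.rbln import RBLNAutoModelForCausalLM

# export=True: resolve LlamaForCausalLM from config.json and compile it
# through RBLNLlamaForCausalLM.
model = RBLNAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    export=True,
    rbln_batch_size=1,
)
model.save_pretrained("./llama-2-7b-rbln")

# export=False: reload the compiled directory; the class is recovered from
# the stored rbln_config rather than re-inferred from HF weights.
model = RBLNAutoModelForCausalLM.from_pretrained("./llama-2-7b-rbln", export=False)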

optimum/rbln/transformers/models/auto/modeling_auto.py

@@ -21,18 +21,31 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+
 from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_CAUSAL_LM_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_CTC_MAPPING,
     MODEL_FOR_CTC_MAPPING_NAMES,
+    MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
     MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
 

@@ -48,48 +61,60 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
 
 
 class RBLNAutoModel(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_MAPPING
+    _model_mapping_names = MODEL_MAPPING_NAMES
 
 
 class RBLNAutoModelForCTC(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_CTC_MAPPING
+    _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
 
 
 class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+    _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+    _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+    _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
+    _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+    _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
+    _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+    _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+    _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+    _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
-    _model_mapping =
+    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
+    _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
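
Each auto class now carries the lazy `_model_mapping` (used by `_get_model_class` to pick the concrete HF class) alongside the plain `_model_mapping_names` dict, which `get_rbln_cls` consults for its supported-architecture check against `_model_mapping_names.values()`. A sketch of the name round trip that check depends on, assuming the helpers in the new `utils/model_utils.py` simply attach and strip the `RBLN` prefix (this diff does not show their actual bodies):

from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


# Assumed behavior of the helpers imported in auto_factory.py above.
def convert_hf_to_rbln_model_name(hf_name: str) -> str:
    return f"RBLN{hf_name}"


def convert_rbln_to_hf_model_name(rbln_name: str) -> str:
    return rbln_name[len("RBLN"):] if rbln_name.startswith("RBLN") else rbln_name


rbln_name = convert_hf_to_rbln_model_name("LlamaForCausalLM")  # "RBLNLlamaForCausalLM"

# Mirrors the check in get_rbln_cls: the stripped name must be a value of the
# task's *_MAPPING_NAMES dict ("llama" -> "LlamaForCausalLM").
assert convert_rbln_to_hf_model_name(rbln_name) in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()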

optimum/rbln/transformers/models/bart/modeling_bart.py

@@ -24,9 +24,9 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
-from transformers import
+from transformers import BartForConditionalGeneration, PretrainedConfig
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ...models.seq2seq import RBLNModelForSeq2SeqLM

@@ -41,9 +41,6 @@ if TYPE_CHECKING:
 
 
 class RBLNBartModel(RBLNModel):
-    original_model_class = BartModel
-    original_config_class = BartConfig
-
     @classmethod
     def _get_rbln_config(
         cls,

@@ -82,7 +79,7 @@ class RBLNBartModel(RBLNModel):
         if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
             rbln_model_input_names = cls.rbln_model_input_names
         elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-            input_names_order = inspect.signature(cls.
+            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
             raise ValueError(
                 "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                 f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
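
The replacement line derives the expected input order from the wrapped HF class's `forward` signature, exposed through the new `hf_class` attribute; the old line apparently referenced the removed `original_model_class` (its tail is truncated in this rendering). The same introspection can be tried against stock transformers with no RBLN runtime:

import inspect

from transformers import BartModel

# Mirrors the error-message hint above: the parameter order of BartModel.forward.
input_names_order = inspect.signature(BartModel.forward).parameters.keys()
print(list(input_names_order)[:4])
# ['self', 'input_ids', 'attention_mask', 'decoder_input_ids'] on transformers 4.x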

optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -25,9 +25,9 @@ import inspect
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
-from transformers import
+from transformers import PretrainedConfig
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
@@ -38,9 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNBertModel(RBLNModel):
-    original_model_class = BertModel
-    original_config_class = BertConfig
-
     @classmethod
     def _get_rbln_config(
         cls,

@@ -75,7 +72,7 @@ class RBLNBertModel(RBLNModel):
         if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
             rbln_model_input_names = cls.rbln_model_input_names
         elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-            input_names_order = inspect.signature(cls.
+            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
             raise ValueError(
                 "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                 f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"

optimum/rbln/transformers/models/clip/modeling_clip.py

@@ -26,19 +26,17 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 from transformers import (
-    AutoConfig,
-    AutoModel,
     CLIPTextConfig,
     CLIPTextModel,
-    CLIPTextModelWithProjection,
     CLIPVisionConfig,
     CLIPVisionModel,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
 
 
 logger = logging.getLogger(__name__)

@@ -58,24 +56,14 @@ class _TextEncoder(torch.nn.Module):
 
 
 class RBLNCLIPTextModel(RBLNModel):
-    original_model_class = CLIPTextModel
-    original_config_class = CLIPTextConfig
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,

@@ -119,7 +107,7 @@ class RBLNCLIPTextModel(RBLNModel):
 
 
 class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
-
+    pass
 
 
 class _VisionEncoder(torch.nn.Module):

@@ -133,20 +121,6 @@ class _VisionEncoder(torch.nn.Module):
 
 
 class RBLNCLIPVisionModel(RBLNModel):
-    original_model_class = CLIPVisionModel
-    original_config_class = CLIPVisionConfig
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _VisionEncoder(model).eval()

@@ -155,7 +129,7 @@ class RBLNCLIPVisionModel(RBLNModel):
     def _get_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "
+        model_config: "CLIPVisionConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
         rbln_batch_size = rbln_kwargs.get("batch_size", 1)

optimum/rbln/transformers/models/decoderonly/__init__.py

@@ -22,12 +22,7 @@
 # from Rebellions Inc.
 
 from .decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyDecoderLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    DynamicNTKScalingRotaryEmbedding,
-    LinearScalingRotaryEmbedding,
     RotaryEmbedding,
     apply_rotary_pos_emb,
     rotate_half,