optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_config.py
CHANGED
@@ -23,24 +23,38 @@
|
|
23
23
|
|
24
24
|
import copy
|
25
25
|
import json
|
26
|
-
from collections import UserDict
|
27
26
|
from dataclasses import asdict, dataclass
|
28
27
|
from pathlib import Path
|
29
28
|
from typing import Any, Dict, List, Optional, Tuple
|
30
29
|
|
30
|
+
import rebel
|
31
31
|
import torch
|
32
32
|
|
33
|
+
from .__version__ import __version__
|
34
|
+
from .utils.runtime_utils import ContextRblnConfig
|
35
|
+
|
33
36
|
|
34
37
|
DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
|
35
38
|
DEFAULT_MOD_NAME = "default"
|
36
39
|
|
37
40
|
|
38
41
|
@dataclass
|
39
|
-
class
|
42
|
+
class RBLNCompileConfig:
|
43
|
+
"""
|
44
|
+
Configuration for RBLN compilation.
|
45
|
+
|
46
|
+
Attributes:
|
47
|
+
compiled_model_name (str): Name of the compiled model.
|
48
|
+
mod_name (str): Name of the RBLN module.
|
49
|
+
input_info (List[Tuple[str, Tuple[int], Optional[str]]]): Information about input tensors.
|
50
|
+
fusion (Optional[bool]): Whether to use fusion optimization.
|
51
|
+
npu (Optional[str]): NPU configuration.
|
52
|
+
tensor_parallel_size (Optional[int]): Size for tensor parallelism.
|
53
|
+
"""
|
54
|
+
|
40
55
|
compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
|
41
|
-
|
56
|
+
mod_name: str = DEFAULT_MOD_NAME
|
42
57
|
input_info: List[Tuple[str, Tuple[int], Optional[str]]] = None
|
43
|
-
batch_size: Optional[int] = None
|
44
58
|
fusion: Optional[bool] = None
|
45
59
|
npu: Optional[str] = None
|
46
60
|
tensor_parallel_size: Optional[int] = None
|
@@ -48,8 +62,14 @@ class RBLNRuntimeConfig:
|
|
48
62
|
@staticmethod
|
49
63
|
def normalize_dtype(dtype):
|
50
64
|
"""
|
51
|
-
framework
|
65
|
+
Convert framework-specific dtype to string representation.
|
52
66
|
i.e. torch.float32 -> "float32"
|
67
|
+
|
68
|
+
Args:
|
69
|
+
dtype: The input dtype (can be string, torch dtype, or numpy dtype).
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
str: The normalized string representation of the dtype.
|
53
73
|
"""
|
54
74
|
if isinstance(dtype, str):
|
55
75
|
return dtype
|
@@ -60,13 +80,12 @@ class RBLNRuntimeConfig:
|
|
60
80
|
return dtype
|
61
81
|
|
62
82
|
def __post_init__(self):
|
63
|
-
self.input_info = [(i[0], i[1],
|
83
|
+
self.input_info = [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]
|
64
84
|
|
65
|
-
def update(self,
|
85
|
+
def update(self, kwargs: Dict[str, Any]):
|
66
86
|
self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
|
67
|
-
self.
|
87
|
+
self.mod_name = kwargs.get("mod_name", self.mod_name)
|
68
88
|
self.input_info = kwargs.get("input_info", self.input_info)
|
69
|
-
self.batch_size = kwargs.get("batch_size", self.batch_size)
|
70
89
|
self.fusion = kwargs.get("fusion", self.fusion)
|
71
90
|
self.npu = kwargs.get("npu", self.npu)
|
72
91
|
self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
|
@@ -86,84 +105,191 @@ class RBLNRuntimeConfig:
|
|
86
105
|
return asdict(self)
|
87
106
|
|
88
107
|
|
89
|
-
|
90
|
-
|
91
|
-
"""Configurations for RBLN model compilation and inference.
|
108
|
+
RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map"]
|
109
|
+
COMPILE_KEYWORDS = ["compiled_model_name", "mod_name", "input_info", "fusion", "npu", "tensor_parallel_size"]
|
92
110
|
|
93
|
-
Args:
|
94
|
-
_rbln_meta (Dict[str, Any], optional):
|
95
|
-
Any rbln-specific configurations.
|
96
|
-
(i.e. max_seq_len for language models, image_size for image models).
|
97
|
-
Defaults to None.
|
98
|
-
"""
|
99
|
-
super().__init__(runtime_cfgs)
|
100
|
-
if _rbln_meta:
|
101
|
-
self.meta = _rbln_meta
|
102
|
-
else:
|
103
|
-
self.meta: Dict[str, Any] = {}
|
104
111
|
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
names = [None] * len(rbln_configs) if names is None else names
|
109
|
-
runtime_cfgs = []
|
110
|
-
for name, cfg in zip(names, rbln_configs):
|
111
|
-
if len(cfg) > 1:
|
112
|
-
msg = (
|
113
|
-
"`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`."
|
114
|
-
f"But got {len(cfg)} `RBLNRuntimeConfig`."
|
115
|
-
)
|
116
|
-
raise RuntimeError(msg)
|
112
|
+
class RBLNConfig:
|
113
|
+
"""
|
114
|
+
Configuration for single RBLN OptimizedModel, representing multiple compiled models.
|
117
115
|
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
116
|
+
Attributes:
|
117
|
+
compile_cfgs (List[RBLNCompileConfig]): Compilation configurations.
|
118
|
+
meta (dict): Metadata including version and class information.
|
119
|
+
runtime_cfg (dict): Runtime-specific configuration.
|
120
|
+
"""
|
123
121
|
|
124
|
-
|
125
|
-
|
122
|
+
# It represents multiple compiled model, one of each can have multiple runtimes.
|
123
|
+
def __init__(
|
124
|
+
self,
|
125
|
+
rbln_cls,
|
126
|
+
compile_cfgs: List[RBLNCompileConfig],
|
127
|
+
rbln_kwargs=None,
|
128
|
+
meta=None,
|
129
|
+
) -> None:
|
130
|
+
if rbln_kwargs is None:
|
131
|
+
rbln_kwargs = {}
|
132
|
+
else:
|
133
|
+
rbln_kwargs = copy.deepcopy(rbln_kwargs)
|
126
134
|
|
127
|
-
|
135
|
+
# meta : class, version and other informations.
|
136
|
+
if meta is None:
|
137
|
+
self.meta = {"version": __version__, "cls": rbln_cls}
|
138
|
+
else:
|
139
|
+
self.meta = meta
|
128
140
|
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
+
# compile_cfgs : compile args for each runtime
|
142
|
+
self.compile_cfgs = compile_cfgs
|
143
|
+
for compile_cfg in self.compile_cfgs:
|
144
|
+
compile_cfg.update(rbln_kwargs)
|
145
|
+
for K in COMPILE_KEYWORDS:
|
146
|
+
rbln_kwargs.pop(K, None)
|
147
|
+
|
148
|
+
# runtime_cfg : Values that don't be saved / loaded.
|
149
|
+
self.runtime_cfg = {}
|
150
|
+
for runtime_key in RUNTIME_KEYWORDS:
|
151
|
+
if runtime_key in rbln_kwargs:
|
152
|
+
self.runtime_cfg[runtime_key] = rbln_kwargs.pop(runtime_key)
|
153
|
+
|
154
|
+
# model_cfg : All user-provided values such as "max_seq_len".
|
155
|
+
self.model_cfg: Dict[str, Any] = rbln_kwargs
|
141
156
|
|
142
157
|
def save(self, dir_path: str):
|
143
158
|
dir_path = Path(dir_path)
|
144
|
-
|
145
|
-
|
159
|
+
|
160
|
+
s_json = {}
|
161
|
+
compile_cfgs = [asdict(cfg) for cfg in self.compile_cfgs]
|
162
|
+
s_json["_compile_cfgs"] = compile_cfgs
|
163
|
+
s_json["_meta"] = self.meta
|
164
|
+
s_json.update(self.model_cfg)
|
165
|
+
|
146
166
|
with open(dir_path / "rbln_config.json", "w") as jsonf:
|
147
|
-
json.dump(
|
167
|
+
json.dump(s_json, jsonf, indent=2)
|
148
168
|
|
149
|
-
@
|
150
|
-
def load(dir_path: str) -> "RBLNConfig":
|
169
|
+
@classmethod
|
170
|
+
def load(cls, dir_path: str) -> "RBLNConfig":
|
151
171
|
dir_path = Path(dir_path)
|
152
172
|
with open(dir_path / "rbln_config.json", "r") as jsonf:
|
153
173
|
config_file = json.load(jsonf)
|
154
|
-
return RBLNConfig.fromdict(config_file)
|
155
174
|
|
156
|
-
|
157
|
-
|
158
|
-
|
175
|
+
return cls.fromdict(config_file)
|
176
|
+
|
177
|
+
@classmethod
|
178
|
+
def fromdict(cls, dic: dict):
|
179
|
+
compile_cfgs = dic.pop("_compile_cfgs")
|
180
|
+
compile_cfgs = [RBLNCompileConfig(**cfg) for cfg in compile_cfgs]
|
181
|
+
|
182
|
+
meta = dic.pop("_meta")
|
183
|
+
rbln_cls = meta["cls"]
|
184
|
+
|
185
|
+
rbln_kwargs = dic
|
186
|
+
return cls(rbln_cls=rbln_cls, compile_cfgs=compile_cfgs, rbln_kwargs=rbln_kwargs, meta=meta)
|
187
|
+
|
188
|
+
def update_runtime_cfg(self, rbln_kwargs: Dict[str, Any]):
|
189
|
+
keys = list(rbln_kwargs.keys())
|
190
|
+
for key in keys:
|
191
|
+
if key in RUNTIME_KEYWORDS:
|
192
|
+
self.runtime_cfg[key] = rbln_kwargs[key]
|
193
|
+
|
194
|
+
def __repr__(self):
|
195
|
+
compile_cfgs_repr = [f"\n {cfg!r}" for cfg in self.compile_cfgs]
|
196
|
+
return (
|
197
|
+
f"RBLNConfig(\n"
|
198
|
+
f" rbln_cls={self.meta['cls']},\n"
|
199
|
+
f" version='{self.meta['version']}',\n"
|
200
|
+
f" compile_cfgs=[{''.join(compile_cfgs_repr)}\n ],\n"
|
201
|
+
f" model_cfg={self.model_cfg},\n"
|
202
|
+
f" runtime_cfg={self.runtime_cfg}\n"
|
203
|
+
f")"
|
204
|
+
)
|
205
|
+
|
206
|
+
@property
|
207
|
+
def create_runtimes(self):
|
208
|
+
context = ContextRblnConfig.get_current_context()["create_runtimes"]
|
209
|
+
if context is not None:
|
210
|
+
return context
|
211
|
+
elif self.runtime_cfg.get("create_runtimes", None) is None:
|
212
|
+
return rebel.npu_is_available()
|
213
|
+
return self.runtime_cfg["create_runtimes"]
|
214
|
+
|
215
|
+
@property
|
216
|
+
def optimize_host_memory(self):
|
217
|
+
context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
|
218
|
+
if context is not None:
|
219
|
+
return context
|
220
|
+
elif self.runtime_cfg.get("optimize_host_memory", None) is None:
|
221
|
+
return True
|
222
|
+
return self.runtime_cfg["optimize_host_memory"]
|
223
|
+
|
224
|
+
@property
|
225
|
+
def device(self):
|
226
|
+
context = ContextRblnConfig.get_current_context()["device"]
|
227
|
+
if context:
|
228
|
+
return context
|
229
|
+
elif self.runtime_cfg.get("device", None) is None:
|
230
|
+
return 0
|
231
|
+
return self.runtime_cfg["device"]
|
232
|
+
|
233
|
+
@property
|
234
|
+
def device_map(self):
|
235
|
+
context = ContextRblnConfig.get_current_context()["device_map"]
|
236
|
+
if context:
|
237
|
+
return context
|
238
|
+
elif self.runtime_cfg.get("device_map", None) is None:
|
239
|
+
rbln_device_map = {}
|
240
|
+
device_val = self.device
|
241
|
+
for cfg in self.compile_cfgs:
|
242
|
+
rbln_device_map[cfg.compiled_model_name] = device_val
|
243
|
+
return rbln_device_map
|
244
|
+
return self.runtime_cfg["device_map"]
|
245
|
+
|
246
|
+
|
247
|
+
def use_rbln_config(fn):
|
248
|
+
"""
|
249
|
+
If the function uses rbln_config and kwargs,
|
250
|
+
then extract `rbln_` prefix from kwargs.
|
251
|
+
|
252
|
+
If rbln_config is already an instance of RBLNConfig, then pass.
|
253
|
+
"""
|
254
|
+
|
255
|
+
def merged_rbln_config_fn(*args, **kwargs):
|
256
|
+
rbln_kwargs = kwargs.pop("rbln_kwargs", None)
|
257
|
+
if rbln_kwargs is not None:
|
258
|
+
raise KeyError("`rbln_kwargs` cannot be specified when using `rbln_config`!")
|
259
|
+
|
260
|
+
rbln_config = kwargs.pop("rbln_config", None)
|
261
|
+
|
262
|
+
keys = list(kwargs.keys())
|
263
|
+
rbln_kwargs = {key[5:]: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
|
264
|
+
|
265
|
+
if isinstance(rbln_config, RBLNConfig):
|
266
|
+
# merge runtime kwargs if exists.
|
267
|
+
runtime_rbln_kwargs = {k: rbln_kwargs.pop(k) for k in RUNTIME_KEYWORDS if k in rbln_kwargs}
|
268
|
+
|
269
|
+
# ignore internal keys and recover "rbln_" prefix
|
270
|
+
RBLN_INTERNAL_KEYS = {"compiled_models", "submodules"}
|
271
|
+
internal_kwargs = {"rbln_" + k: rbln_kwargs.pop(k) for k in RBLN_INTERNAL_KEYS if k in rbln_kwargs}
|
272
|
+
|
273
|
+
if len(rbln_kwargs) > 0:
|
274
|
+
raise KeyError(
|
275
|
+
f"Failed to merging function argument : {rbln_kwargs.keys()}. "
|
276
|
+
"If you passed `rbln_config` an instance of `RBLNConfig`, "
|
277
|
+
"then none `rbln_` prefixes are allowed to be passed."
|
278
|
+
)
|
279
|
+
rbln_config.update_runtime_cfg(runtime_rbln_kwargs)
|
280
|
+
return fn(*args, **kwargs, **internal_kwargs, rbln_config=rbln_config)
|
281
|
+
|
282
|
+
elif rbln_config is None:
|
283
|
+
rbln_config_dict = {}
|
159
284
|
|
160
|
-
@staticmethod
|
161
|
-
def fromdict(dic: dict):
|
162
|
-
runtime_cfgs = {
|
163
|
-
k: [RBLNRuntimeConfig(**cfg) for cfg in cfgs] for k, cfgs in dic.items() if k != "rbln_config_meta"
|
164
|
-
}
|
165
|
-
if "rbln_config_meta" in dic:
|
166
|
-
meta = dic["rbln_config_meta"]
|
167
285
|
else:
|
168
|
-
|
169
|
-
|
286
|
+
rbln_config_dict = rbln_config
|
287
|
+
|
288
|
+
for key in rbln_config_dict:
|
289
|
+
if key in rbln_kwargs:
|
290
|
+
raise KeyError(f"Duplicated key in both `rbln_config` and rbln_{key}.")
|
291
|
+
|
292
|
+
rbln_kwargs.update(rbln_config_dict)
|
293
|
+
return fn(*args, **kwargs, rbln_config=rbln_kwargs)
|
294
|
+
|
295
|
+
return merged_rbln_config_fn
|
@@ -30,17 +30,38 @@ _import_structure = {
|
|
30
30
|
"cache_utils": ["RebelDynamicCache"],
|
31
31
|
"generation": ["BatchTextIteratorStreamer"],
|
32
32
|
"models": [
|
33
|
+
"RBLNAutoModel",
|
34
|
+
"RBLNAutoModelForAudioClassification",
|
35
|
+
"RBLNAutoModelForCausalLM",
|
36
|
+
"RBLNAutoModelForCTC",
|
37
|
+
"RBLNAutoModelForDepthEstimation",
|
38
|
+
"RBLNAutoModelForImageClassification",
|
39
|
+
"RBLNAutoModelForMaskedLM",
|
40
|
+
"RBLNAutoModelForQuestionAnswering",
|
41
|
+
"RBLNAutoModelForSeq2SeqLM",
|
42
|
+
"RBLNAutoModelForSequenceClassification",
|
43
|
+
"RBLNAutoModelForSpeechSeq2Seq",
|
44
|
+
"RBLNAutoModelForVision2Seq",
|
45
|
+
"RBLNBartForConditionalGeneration",
|
46
|
+
"RBLNBartModel",
|
47
|
+
"RBLNBertModel",
|
33
48
|
"RBLNCLIPTextModel",
|
34
49
|
"RBLNCLIPTextModelWithProjection",
|
50
|
+
"RBLNCLIPVisionModel",
|
35
51
|
"RBLNDPTForDepthEstimation",
|
52
|
+
"RBLNExaoneForCausalLM",
|
36
53
|
"RBLNGemmaForCausalLM",
|
37
54
|
"RBLNGPT2LMHeadModel",
|
55
|
+
"RBLNQwen2ForCausalLM",
|
38
56
|
"RBLNWav2Vec2ForCTC",
|
39
57
|
"RBLNWhisperForConditionalGeneration",
|
40
58
|
"RBLNLlamaForCausalLM",
|
59
|
+
"RBLNPhiForCausalLM",
|
60
|
+
"RBLNT5ForConditionalGeneration",
|
61
|
+
"RBLNLlavaNextForConditionalGeneration",
|
41
62
|
"RBLNMidmLMHeadModel",
|
42
|
-
"RBLNMistralForCausalLM",
|
43
63
|
"RBLNXLMRobertaModel",
|
64
|
+
"RBLNMistralForCausalLM",
|
44
65
|
],
|
45
66
|
}
|
46
67
|
|
@@ -48,14 +69,35 @@ if TYPE_CHECKING:
|
|
48
69
|
from .cache_utils import RebelDynamicCache
|
49
70
|
from .generation import BatchTextIteratorStreamer
|
50
71
|
from .models import (
|
72
|
+
RBLNAutoModel,
|
73
|
+
RBLNAutoModelForAudioClassification,
|
74
|
+
RBLNAutoModelForCausalLM,
|
75
|
+
RBLNAutoModelForCTC,
|
76
|
+
RBLNAutoModelForDepthEstimation,
|
77
|
+
RBLNAutoModelForImageClassification,
|
78
|
+
RBLNAutoModelForMaskedLM,
|
79
|
+
RBLNAutoModelForQuestionAnswering,
|
80
|
+
RBLNAutoModelForSeq2SeqLM,
|
81
|
+
RBLNAutoModelForSequenceClassification,
|
82
|
+
RBLNAutoModelForSpeechSeq2Seq,
|
83
|
+
RBLNAutoModelForVision2Seq,
|
84
|
+
RBLNBartForConditionalGeneration,
|
85
|
+
RBLNBartModel,
|
86
|
+
RBLNBertModel,
|
51
87
|
RBLNCLIPTextModel,
|
52
88
|
RBLNCLIPTextModelWithProjection,
|
89
|
+
RBLNCLIPVisionModel,
|
53
90
|
RBLNDPTForDepthEstimation,
|
91
|
+
RBLNExaoneForCausalLM,
|
54
92
|
RBLNGemmaForCausalLM,
|
55
93
|
RBLNGPT2LMHeadModel,
|
56
94
|
RBLNLlamaForCausalLM,
|
95
|
+
RBLNLlavaNextForConditionalGeneration,
|
57
96
|
RBLNMidmLMHeadModel,
|
58
97
|
RBLNMistralForCausalLM,
|
98
|
+
RBLNPhiForCausalLM,
|
99
|
+
RBLNQwen2ForCausalLM,
|
100
|
+
RBLNT5ForConditionalGeneration,
|
59
101
|
RBLNWav2Vec2ForCTC,
|
60
102
|
RBLNWhisperForConditionalGeneration,
|
61
103
|
RBLNXLMRobertaModel,
|
@@ -21,13 +21,35 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
|
24
|
+
|
25
|
+
from .auto import (
|
26
|
+
RBLNAutoModel,
|
27
|
+
RBLNAutoModelForAudioClassification,
|
28
|
+
RBLNAutoModelForCausalLM,
|
29
|
+
RBLNAutoModelForCTC,
|
30
|
+
RBLNAutoModelForDepthEstimation,
|
31
|
+
RBLNAutoModelForImageClassification,
|
32
|
+
RBLNAutoModelForMaskedLM,
|
33
|
+
RBLNAutoModelForQuestionAnswering,
|
34
|
+
RBLNAutoModelForSeq2SeqLM,
|
35
|
+
RBLNAutoModelForSequenceClassification,
|
36
|
+
RBLNAutoModelForSpeechSeq2Seq,
|
37
|
+
RBLNAutoModelForVision2Seq,
|
38
|
+
)
|
39
|
+
from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
|
40
|
+
from .bert import RBLNBertModel
|
41
|
+
from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
|
25
42
|
from .dpt import RBLNDPTForDepthEstimation
|
43
|
+
from .exaone import RBLNExaoneForCausalLM
|
26
44
|
from .gemma import RBLNGemmaForCausalLM
|
27
45
|
from .gpt2 import RBLNGPT2LMHeadModel
|
28
46
|
from .llama import RBLNLlamaForCausalLM
|
47
|
+
from .llava_next import RBLNLlavaNextForConditionalGeneration
|
29
48
|
from .midm import RBLNMidmLMHeadModel
|
30
49
|
from .mistral import RBLNMistralForCausalLM
|
50
|
+
from .phi import RBLNPhiForCausalLM
|
51
|
+
from .qwen2 import RBLNQwen2ForCausalLM
|
52
|
+
from .t5 import RBLNT5ForConditionalGeneration
|
31
53
|
from .wav2vec2 import RBLNWav2Vec2ForCTC
|
32
54
|
from .whisper import RBLNWhisperForConditionalGeneration
|
33
55
|
from .xlm_roberta import RBLNXLMRobertaModel
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from .modeling_auto import (
|
2
|
+
RBLNAutoModel,
|
3
|
+
RBLNAutoModelForAudioClassification,
|
4
|
+
RBLNAutoModelForCausalLM,
|
5
|
+
RBLNAutoModelForCTC,
|
6
|
+
RBLNAutoModelForDepthEstimation,
|
7
|
+
RBLNAutoModelForImageClassification,
|
8
|
+
RBLNAutoModelForMaskedLM,
|
9
|
+
RBLNAutoModelForQuestionAnswering,
|
10
|
+
RBLNAutoModelForSeq2SeqLM,
|
11
|
+
RBLNAutoModelForSequenceClassification,
|
12
|
+
RBLNAutoModelForSpeechSeq2Seq,
|
13
|
+
RBLNAutoModelForVision2Seq,
|
14
|
+
)
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import importlib
|
25
|
+
|
26
|
+
from transformers import AutoConfig
|
27
|
+
|
28
|
+
|
29
|
+
class _BaseAutoModelClass:
|
30
|
+
# Base class for auto models.
|
31
|
+
_model_mapping = None
|
32
|
+
|
33
|
+
def __init__(self, *args, **kwargs):
|
34
|
+
raise EnvironmentError(
|
35
|
+
f"{self.__class__.__name__} is designed to be instantiated "
|
36
|
+
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
37
|
+
f"`{self.__class__.__name__}.from_config(config)` methods."
|
38
|
+
)
|
39
|
+
|
40
|
+
@classmethod
|
41
|
+
def get_rbln_cls(
|
42
|
+
cls,
|
43
|
+
model_id,
|
44
|
+
*args,
|
45
|
+
**kwargs,
|
46
|
+
):
|
47
|
+
# kwargs.update({"return_unused_kwargs": True})
|
48
|
+
config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, **kwargs)[0]
|
49
|
+
|
50
|
+
if len(config.architectures) > 1:
|
51
|
+
raise ValueError(
|
52
|
+
f"Model with ID '{model_id}' has multiple architectures defined in the configuration: "
|
53
|
+
f"{config.architectures}. `_BaseAutoModelClass` require exactly one architecture. "
|
54
|
+
)
|
55
|
+
|
56
|
+
architecture_name = config.architectures[0]
|
57
|
+
if architecture_name not in cls._model_mapping.values():
|
58
|
+
raise ValueError(
|
59
|
+
f"The 'RBLN{architecture_name}' architecture is not supported by `{cls.__name__}.from_pretrained()`."
|
60
|
+
"Please use the appropriate class's `from_pretrained()` method to load this model."
|
61
|
+
)
|
62
|
+
|
63
|
+
rbln_class_name = "RBLN" + architecture_name
|
64
|
+
module = importlib.import_module("optimum.rbln")
|
65
|
+
|
66
|
+
try:
|
67
|
+
rbln_cls = getattr(module, rbln_class_name)
|
68
|
+
except AttributeError as e:
|
69
|
+
raise AttributeError(
|
70
|
+
f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{model_id}'. "
|
71
|
+
"Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
|
72
|
+
) from e
|
73
|
+
|
74
|
+
return rbln_cls
|
75
|
+
|
76
|
+
@classmethod
|
77
|
+
def from_pretrained(
|
78
|
+
cls,
|
79
|
+
model_id,
|
80
|
+
*args,
|
81
|
+
**kwargs,
|
82
|
+
):
|
83
|
+
rbln_cls = cls.get_rbln_cls(model_id, *args, **kwargs)
|
84
|
+
return rbln_cls.from_pretrained(model_id, *args, **kwargs)
|
@@ -0,0 +1,95 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from transformers.models.auto.modeling_auto import (
|
25
|
+
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
|
26
|
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
27
|
+
MODEL_FOR_CTC_MAPPING_NAMES,
|
28
|
+
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
|
29
|
+
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
|
30
|
+
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
|
31
|
+
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
|
32
|
+
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
33
|
+
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
|
34
|
+
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
|
35
|
+
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
|
36
|
+
MODEL_MAPPING_NAMES,
|
37
|
+
)
|
38
|
+
|
39
|
+
from .auto_factory import _BaseAutoModelClass
|
40
|
+
|
41
|
+
|
42
|
+
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
|
43
|
+
{
|
44
|
+
"midm": "MidmLMHeadModel",
|
45
|
+
"exaone": "ExaoneForCausalLM",
|
46
|
+
}
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
class RBLNAutoModel(_BaseAutoModelClass):
|
51
|
+
_model_mapping = MODEL_MAPPING_NAMES
|
52
|
+
|
53
|
+
|
54
|
+
class RBLNAutoModelForCTC(_BaseAutoModelClass):
|
55
|
+
_model_mapping = MODEL_FOR_CTC_MAPPING_NAMES
|
56
|
+
|
57
|
+
|
58
|
+
class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
|
59
|
+
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
60
|
+
|
61
|
+
|
62
|
+
class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
|
63
|
+
_model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
64
|
+
|
65
|
+
|
66
|
+
class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
|
67
|
+
_model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
68
|
+
|
69
|
+
|
70
|
+
class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
|
71
|
+
_model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
|
72
|
+
|
73
|
+
|
74
|
+
class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
|
75
|
+
_model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
|
76
|
+
|
77
|
+
|
78
|
+
class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
|
79
|
+
_model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
|
80
|
+
|
81
|
+
|
82
|
+
class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
|
83
|
+
_model_mapping = MODEL_FOR_MASKED_LM_MAPPING_NAMES
|
84
|
+
|
85
|
+
|
86
|
+
class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
|
87
|
+
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
|
88
|
+
|
89
|
+
|
90
|
+
class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
|
91
|
+
_model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
|
92
|
+
|
93
|
+
|
94
|
+
class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
|
95
|
+
_model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|