optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
@@ -23,24 +23,38 @@
23
23
 
24
24
  import copy
25
25
  import json
26
- from collections import UserDict
27
26
  from dataclasses import asdict, dataclass
28
27
  from pathlib import Path
29
28
  from typing import Any, Dict, List, Optional, Tuple
30
29
 
30
+ import rebel
31
31
  import torch
32
32
 
33
+ from .__version__ import __version__
34
+ from .utils.runtime_utils import ContextRblnConfig
35
+
33
36
 
34
37
  DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
35
38
  DEFAULT_MOD_NAME = "default"
36
39
 
37
40
 
38
41
  @dataclass
39
- class RBLNRuntimeConfig:
42
+ class RBLNCompileConfig:
43
+ """
44
+ Configuration for RBLN compilation.
45
+
46
+ Attributes:
47
+ compiled_model_name (str): Name of the compiled model.
48
+ mod_name (str): Name of the RBLN module.
49
+ input_info (List[Tuple[str, Tuple[int], Optional[str]]]): Information about input tensors.
50
+ fusion (Optional[bool]): Whether to use fusion optimization.
51
+ npu (Optional[str]): NPU configuration.
52
+ tensor_parallel_size (Optional[int]): Size for tensor parallelism.
53
+ """
54
+
40
55
  compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
41
- rbln_mod_name: str = DEFAULT_MOD_NAME
56
+ mod_name: str = DEFAULT_MOD_NAME
42
57
  input_info: List[Tuple[str, Tuple[int], Optional[str]]] = None
43
- batch_size: Optional[int] = None
44
58
  fusion: Optional[bool] = None
45
59
  npu: Optional[str] = None
46
60
  tensor_parallel_size: Optional[int] = None
@@ -48,8 +62,14 @@ class RBLNRuntimeConfig:
48
62
  @staticmethod
49
63
  def normalize_dtype(dtype):
50
64
  """
51
- framework's dtype to string.
65
+ Convert framework-specific dtype to string representation.
52
66
  i.e. torch.float32 -> "float32"
67
+
68
+ Args:
69
+ dtype: The input dtype (can be string, torch dtype, or numpy dtype).
70
+
71
+ Returns:
72
+ str: The normalized string representation of the dtype.
53
73
  """
54
74
  if isinstance(dtype, str):
55
75
  return dtype
@@ -60,13 +80,12 @@ class RBLNRuntimeConfig:
60
80
  return dtype
61
81
 
62
82
  def __post_init__(self):
63
- self.input_info = [(i[0], i[1], RBLNRuntimeConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]
83
+ self.input_info = [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]
64
84
 
65
- def update(self, **kwargs):
85
+ def update(self, kwargs: Dict[str, Any]):
66
86
  self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
67
- self.rbln_mod_name = kwargs.get("rbln_mod_name", self.rbln_mod_name)
87
+ self.mod_name = kwargs.get("mod_name", self.mod_name)
68
88
  self.input_info = kwargs.get("input_info", self.input_info)
69
- self.batch_size = kwargs.get("batch_size", self.batch_size)
70
89
  self.fusion = kwargs.get("fusion", self.fusion)
71
90
  self.npu = kwargs.get("npu", self.npu)
72
91
  self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
@@ -86,84 +105,191 @@ class RBLNRuntimeConfig:
86
105
  return asdict(self)
87
106
 
88
107
 
89
- class RBLNConfig(UserDict):
90
- def __init__(self, runtime_cfgs: Dict[str, List[RBLNRuntimeConfig]], _rbln_meta: Dict[str, Any] = None):
91
- """Configurations for RBLN model compilation and inference.
108
+ RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map"]
109
+ COMPILE_KEYWORDS = ["compiled_model_name", "mod_name", "input_info", "fusion", "npu", "tensor_parallel_size"]
92
110
 
93
- Args:
94
- _rbln_meta (Dict[str, Any], optional):
95
- Any rbln-specific configurations.
96
- (i.e. max_seq_len for language models, image_size for image models).
97
- Defaults to None.
98
- """
99
- super().__init__(runtime_cfgs)
100
- if _rbln_meta:
101
- self.meta = _rbln_meta
102
- else:
103
- self.meta: Dict[str, Any] = {}
104
111
 
105
- @staticmethod
106
- def from_rbln_configs(rbln_configs: List["RBLNConfig"], names: Optional[List[str]] = None) -> "RBLNConfig":
107
- # assume each rbln_config has exact one rbln_runtime_config
108
- names = [None] * len(rbln_configs) if names is None else names
109
- runtime_cfgs = []
110
- for name, cfg in zip(names, rbln_configs):
111
- if len(cfg) > 1:
112
- msg = (
113
- "`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`."
114
- f"But got {len(cfg)} `RBLNRuntimeConfig`."
115
- )
116
- raise RuntimeError(msg)
112
class RBLNConfig:
    """
    Configuration for single RBLN OptimizedModel, representing multiple compiled models.

    Attributes:
        compile_cfgs (List[RBLNCompileConfig]): Compilation configurations.
        meta (dict): Metadata including version and class information.
        runtime_cfg (dict): Runtime-specific configuration. `save` does not write it.
        model_cfg (dict): Every remaining user-provided value such as "max_seq_len".
    """

    # It represents multiple compiled models, each of which can have multiple runtimes.
    def __init__(
        self,
        rbln_cls,
        compile_cfgs: List["RBLNCompileConfig"],
        rbln_kwargs=None,
        meta=None,
    ) -> None:
        """
        Split `rbln_kwargs` into compile, runtime and model settings.

        Args:
            rbln_cls: Value stored under `meta["cls"]` when `meta` is not given.
            compile_cfgs: Compile configurations; each one is updated in place
                from `rbln_kwargs`.
            rbln_kwargs: Optional user-provided settings. The caller's object is
                never modified because a deep copy is taken first.
            meta: Optional pre-built metadata dict, used as-is when provided.
        """
        # Keys are popped from this dict below, so never operate on the caller's object.
        rbln_kwargs = {} if rbln_kwargs is None else copy.deepcopy(rbln_kwargs)

        # meta : class, version and other information.
        self.meta = {"version": __version__, "cls": rbln_cls} if meta is None else meta

        # compile_cfgs : compile args for each runtime
        self.compile_cfgs = compile_cfgs
        for compile_cfg in self.compile_cfgs:
            compile_cfg.update(rbln_kwargs)
        for compile_key in COMPILE_KEYWORDS:
            rbln_kwargs.pop(compile_key, None)

        # runtime_cfg : values that are neither saved nor loaded.
        self.runtime_cfg = {}
        for runtime_key in RUNTIME_KEYWORDS:
            if runtime_key in rbln_kwargs:
                self.runtime_cfg[runtime_key] = rbln_kwargs.pop(runtime_key)

        # model_cfg : All user-provided values such as "max_seq_len".
        self.model_cfg: Dict[str, Any] = rbln_kwargs

    def save(self, dir_path: str):
        """Write compile configs, meta and model_cfg to `<dir_path>/rbln_config.json`."""
        dir_path = Path(dir_path)

        s_json = {
            "_compile_cfgs": [asdict(cfg) for cfg in self.compile_cfgs],
            "_meta": self.meta,
        }
        # model_cfg entries are stored flat, next to the two reserved "_" keys.
        s_json.update(self.model_cfg)

        with open(dir_path / "rbln_config.json", "w") as jsonf:
            json.dump(s_json, jsonf, indent=2)

    @classmethod
    def load(cls, dir_path: str) -> "RBLNConfig":
        """Read `<dir_path>/rbln_config.json` and rebuild the configuration from it."""
        dir_path = Path(dir_path)
        with open(dir_path / "rbln_config.json", "r") as jsonf:
            config_file = json.load(jsonf)

        return cls.fromdict(config_file)

    @classmethod
    def fromdict(cls, dic: dict):
        """
        Build a configuration from the dict layout produced by `save`.

        Args:
            dic: Mapping holding "_compile_cfgs", "_meta" and any model settings.
                It is left unmodified.

        Raises:
            KeyError: If "_compile_cfgs" or "_meta" is missing.
        """
        # Shallow copy: the reserved keys are popped, which must not leak to the caller.
        remaining = dict(dic)
        compile_cfgs = [RBLNCompileConfig(**cfg) for cfg in remaining.pop("_compile_cfgs")]

        meta = remaining.pop("_meta")
        rbln_cls = meta["cls"]

        return cls(rbln_cls=rbln_cls, compile_cfgs=compile_cfgs, rbln_kwargs=remaining, meta=meta)

    def update_runtime_cfg(self, rbln_kwargs: Dict[str, Any]):
        """Copy every runtime keyword found in `rbln_kwargs` into `runtime_cfg`."""
        for key, value in rbln_kwargs.items():
            if key in RUNTIME_KEYWORDS:
                self.runtime_cfg[key] = value

    def __repr__(self):
        compile_cfgs_repr = [f"\n {cfg!r}" for cfg in self.compile_cfgs]
        return (
            f"RBLNConfig(\n"
            f" rbln_cls={self.meta['cls']},\n"
            f" version='{self.meta['version']}',\n"
            f" compile_cfgs=[{''.join(compile_cfgs_repr)}\n ],\n"
            f" model_cfg={self.model_cfg},\n"
            f" runtime_cfg={self.runtime_cfg}\n"
            f")"
        )

    @property
    def create_runtimes(self):
        """Context value if set, else the stored value, else `rebel.npu_is_available()`."""
        context = ContextRblnConfig.get_current_context()["create_runtimes"]
        if context is not None:
            return context
        elif self.runtime_cfg.get("create_runtimes", None) is None:
            return rebel.npu_is_available()
        return self.runtime_cfg["create_runtimes"]

    @property
    def optimize_host_memory(self):
        """Context value if set, else the stored value, else True."""
        context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
        if context is not None:
            return context
        elif self.runtime_cfg.get("optimize_host_memory", None) is None:
            return True
        return self.runtime_cfg["optimize_host_memory"]

    @property
    def device(self):
        """Context value if set, else the stored value, else 0."""
        context = ContextRblnConfig.get_current_context()["device"]
        # Compare against None rather than truthiness: device 0 is a valid
        # context choice and must still take priority over runtime_cfg.
        if context is not None:
            return context
        elif self.runtime_cfg.get("device", None) is None:
            return 0
        return self.runtime_cfg["device"]

    @property
    def device_map(self):
        """
        Context value if non-empty, else the stored value, else a map that
        assigns `self.device` to every compiled model name.
        """
        context = ContextRblnConfig.get_current_context()["device_map"]
        # Truthiness is kept here: an empty mapping is handled like "not set".
        if context:
            return context
        elif self.runtime_cfg.get("device_map", None) is None:
            device_val = self.device
            return {cfg.compiled_model_name: device_val for cfg in self.compile_cfgs}
        return self.runtime_cfg["device_map"]
245
+
246
+
247
def use_rbln_config(fn):
    """
    If the function uses rbln_config and kwargs,
    then extract `rbln_` prefix from kwargs.

    Every keyword argument named `rbln_<key>` is removed from the call and
    collected as `<key>`. What happens next depends on `rbln_config`:

    - `RBLNConfig` instance: only runtime keywords and the internal keys
      ("compiled_models", "submodules") may accompany it. Runtime keywords are
      merged into the instance; internal keys are forwarded with their
      `rbln_` prefix restored.
    - None or a dict: the collected keywords and the dict are merged and the
      result is passed on as `rbln_config`.

    Raises (from the wrapper):
        KeyError: If `rbln_kwargs` is passed, if a key appears both in
            `rbln_config` and as `rbln_<key>`, or if non-runtime `rbln_`
            keywords accompany an `RBLNConfig` instance.
    """
    # Function-scope import keeps this block independent of the file's import section.
    import functools

    prefix_len = len("rbln_")

    # Preserve the wrapped function's name, docstring and signature metadata.
    @functools.wraps(fn)
    def merged_rbln_config_fn(*args, **kwargs):
        rbln_kwargs = kwargs.pop("rbln_kwargs", None)
        if rbln_kwargs is not None:
            raise KeyError("`rbln_kwargs` cannot be specified when using `rbln_config`!")

        # Popped before the prefix scan so "rbln_config" itself is never treated as an option.
        rbln_config = kwargs.pop("rbln_config", None)

        prefixed = [key for key in kwargs if key.startswith("rbln_")]
        rbln_kwargs = {key[prefix_len:]: kwargs.pop(key) for key in prefixed}

        if rbln_config is None:
            rbln_config_dict = {}
        elif isinstance(rbln_config, dict) or not isinstance(rbln_config, RBLNConfig):
            # Anything that is not an RBLNConfig instance is handled as a mapping of options.
            rbln_config_dict = rbln_config
        else:
            # merge runtime kwargs if exists.
            runtime_rbln_kwargs = {k: rbln_kwargs.pop(k) for k in RUNTIME_KEYWORDS if k in rbln_kwargs}

            # ignore internal keys and recover "rbln_" prefix
            RBLN_INTERNAL_KEYS = {"compiled_models", "submodules"}
            internal_kwargs = {"rbln_" + k: rbln_kwargs.pop(k) for k in RBLN_INTERNAL_KEYS if k in rbln_kwargs}

            if len(rbln_kwargs) > 0:
                raise KeyError(
                    f"Failed to merging function argument : {rbln_kwargs.keys()}. "
                    "If you passed `rbln_config` an instance of `RBLNConfig`, "
                    "then none `rbln_` prefixes are allowed to be passed."
                )
            rbln_config.update_runtime_cfg(runtime_rbln_kwargs)
            return fn(*args, **kwargs, **internal_kwargs, rbln_config=rbln_config)

        for key in rbln_config_dict:
            if key in rbln_kwargs:
                raise KeyError(f"Duplicated key in both `rbln_config` and rbln_{key}.")

        rbln_kwargs.update(rbln_config_dict)
        return fn(*args, **kwargs, rbln_config=rbln_kwargs)

    return merged_rbln_config_fn
@@ -30,17 +30,38 @@ _import_structure = {
30
30
  "cache_utils": ["RebelDynamicCache"],
31
31
  "generation": ["BatchTextIteratorStreamer"],
32
32
  "models": [
33
+ "RBLNAutoModel",
34
+ "RBLNAutoModelForAudioClassification",
35
+ "RBLNAutoModelForCausalLM",
36
+ "RBLNAutoModelForCTC",
37
+ "RBLNAutoModelForDepthEstimation",
38
+ "RBLNAutoModelForImageClassification",
39
+ "RBLNAutoModelForMaskedLM",
40
+ "RBLNAutoModelForQuestionAnswering",
41
+ "RBLNAutoModelForSeq2SeqLM",
42
+ "RBLNAutoModelForSequenceClassification",
43
+ "RBLNAutoModelForSpeechSeq2Seq",
44
+ "RBLNAutoModelForVision2Seq",
45
+ "RBLNBartForConditionalGeneration",
46
+ "RBLNBartModel",
47
+ "RBLNBertModel",
33
48
  "RBLNCLIPTextModel",
34
49
  "RBLNCLIPTextModelWithProjection",
50
+ "RBLNCLIPVisionModel",
35
51
  "RBLNDPTForDepthEstimation",
52
+ "RBLNExaoneForCausalLM",
36
53
  "RBLNGemmaForCausalLM",
37
54
  "RBLNGPT2LMHeadModel",
55
+ "RBLNQwen2ForCausalLM",
38
56
  "RBLNWav2Vec2ForCTC",
39
57
  "RBLNWhisperForConditionalGeneration",
40
58
  "RBLNLlamaForCausalLM",
59
+ "RBLNPhiForCausalLM",
60
+ "RBLNT5ForConditionalGeneration",
61
+ "RBLNLlavaNextForConditionalGeneration",
41
62
  "RBLNMidmLMHeadModel",
42
- "RBLNMistralForCausalLM",
43
63
  "RBLNXLMRobertaModel",
64
+ "RBLNMistralForCausalLM",
44
65
  ],
45
66
  }
46
67
 
@@ -48,14 +69,35 @@ if TYPE_CHECKING:
48
69
  from .cache_utils import RebelDynamicCache
49
70
  from .generation import BatchTextIteratorStreamer
50
71
  from .models import (
72
+ RBLNAutoModel,
73
+ RBLNAutoModelForAudioClassification,
74
+ RBLNAutoModelForCausalLM,
75
+ RBLNAutoModelForCTC,
76
+ RBLNAutoModelForDepthEstimation,
77
+ RBLNAutoModelForImageClassification,
78
+ RBLNAutoModelForMaskedLM,
79
+ RBLNAutoModelForQuestionAnswering,
80
+ RBLNAutoModelForSeq2SeqLM,
81
+ RBLNAutoModelForSequenceClassification,
82
+ RBLNAutoModelForSpeechSeq2Seq,
83
+ RBLNAutoModelForVision2Seq,
84
+ RBLNBartForConditionalGeneration,
85
+ RBLNBartModel,
86
+ RBLNBertModel,
51
87
  RBLNCLIPTextModel,
52
88
  RBLNCLIPTextModelWithProjection,
89
+ RBLNCLIPVisionModel,
53
90
  RBLNDPTForDepthEstimation,
91
+ RBLNExaoneForCausalLM,
54
92
  RBLNGemmaForCausalLM,
55
93
  RBLNGPT2LMHeadModel,
56
94
  RBLNLlamaForCausalLM,
95
+ RBLNLlavaNextForConditionalGeneration,
57
96
  RBLNMidmLMHeadModel,
58
97
  RBLNMistralForCausalLM,
98
+ RBLNPhiForCausalLM,
99
+ RBLNQwen2ForCausalLM,
100
+ RBLNT5ForConditionalGeneration,
59
101
  RBLNWav2Vec2ForCTC,
60
102
  RBLNWhisperForConditionalGeneration,
61
103
  RBLNXLMRobertaModel,
@@ -21,13 +21,35 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
24
+
25
+ from .auto import (
26
+ RBLNAutoModel,
27
+ RBLNAutoModelForAudioClassification,
28
+ RBLNAutoModelForCausalLM,
29
+ RBLNAutoModelForCTC,
30
+ RBLNAutoModelForDepthEstimation,
31
+ RBLNAutoModelForImageClassification,
32
+ RBLNAutoModelForMaskedLM,
33
+ RBLNAutoModelForQuestionAnswering,
34
+ RBLNAutoModelForSeq2SeqLM,
35
+ RBLNAutoModelForSequenceClassification,
36
+ RBLNAutoModelForSpeechSeq2Seq,
37
+ RBLNAutoModelForVision2Seq,
38
+ )
39
+ from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
40
+ from .bert import RBLNBertModel
41
+ from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
25
42
  from .dpt import RBLNDPTForDepthEstimation
43
+ from .exaone import RBLNExaoneForCausalLM
26
44
  from .gemma import RBLNGemmaForCausalLM
27
45
  from .gpt2 import RBLNGPT2LMHeadModel
28
46
  from .llama import RBLNLlamaForCausalLM
47
+ from .llava_next import RBLNLlavaNextForConditionalGeneration
29
48
  from .midm import RBLNMidmLMHeadModel
30
49
  from .mistral import RBLNMistralForCausalLM
50
+ from .phi import RBLNPhiForCausalLM
51
+ from .qwen2 import RBLNQwen2ForCausalLM
52
+ from .t5 import RBLNT5ForConditionalGeneration
31
53
  from .wav2vec2 import RBLNWav2Vec2ForCTC
32
54
  from .whisper import RBLNWhisperForConditionalGeneration
33
55
  from .xlm_roberta import RBLNXLMRobertaModel
@@ -0,0 +1,14 @@
1
+ from .modeling_auto import (
2
+ RBLNAutoModel,
3
+ RBLNAutoModelForAudioClassification,
4
+ RBLNAutoModelForCausalLM,
5
+ RBLNAutoModelForCTC,
6
+ RBLNAutoModelForDepthEstimation,
7
+ RBLNAutoModelForImageClassification,
8
+ RBLNAutoModelForMaskedLM,
9
+ RBLNAutoModelForQuestionAnswering,
10
+ RBLNAutoModelForSeq2SeqLM,
11
+ RBLNAutoModelForSequenceClassification,
12
+ RBLNAutoModelForSpeechSeq2Seq,
13
+ RBLNAutoModelForVision2Seq,
14
+ )
@@ -0,0 +1,84 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import importlib
25
+
26
+ from transformers import AutoConfig
27
+
28
+
29
+ class _BaseAutoModelClass:
30
+ # Base class for auto models.
31
+ _model_mapping = None
32
+
33
+ def __init__(self, *args, **kwargs):
34
+ raise EnvironmentError(
35
+ f"{self.__class__.__name__} is designed to be instantiated "
36
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
37
+ f"`{self.__class__.__name__}.from_config(config)` methods."
38
+ )
39
+
40
+ @classmethod
41
+ def get_rbln_cls(
42
+ cls,
43
+ model_id,
44
+ *args,
45
+ **kwargs,
46
+ ):
47
+ # kwargs.update({"return_unused_kwargs": True})
48
+ config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, **kwargs)[0]
49
+
50
+ if len(config.architectures) > 1:
51
+ raise ValueError(
52
+ f"Model with ID '{model_id}' has multiple architectures defined in the configuration: "
53
+ f"{config.architectures}. `_BaseAutoModelClass` require exactly one architecture. "
54
+ )
55
+
56
+ architecture_name = config.architectures[0]
57
+ if architecture_name not in cls._model_mapping.values():
58
+ raise ValueError(
59
+ f"The 'RBLN{architecture_name}' architecture is not supported by `{cls.__name__}.from_pretrained()`."
60
+ "Please use the appropriate class's `from_pretrained()` method to load this model."
61
+ )
62
+
63
+ rbln_class_name = "RBLN" + architecture_name
64
+ module = importlib.import_module("optimum.rbln")
65
+
66
+ try:
67
+ rbln_cls = getattr(module, rbln_class_name)
68
+ except AttributeError as e:
69
+ raise AttributeError(
70
+ f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{model_id}'. "
71
+ "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
72
+ ) from e
73
+
74
+ return rbln_cls
75
+
76
+ @classmethod
77
+ def from_pretrained(
78
+ cls,
79
+ model_id,
80
+ *args,
81
+ **kwargs,
82
+ ):
83
+ rbln_cls = cls.get_rbln_cls(model_id, *args, **kwargs)
84
+ return rbln_cls.from_pretrained(model_id, *args, **kwargs)
@@ -0,0 +1,95 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from transformers.models.auto.modeling_auto import (
25
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
26
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
27
+ MODEL_FOR_CTC_MAPPING_NAMES,
28
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
29
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
30
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
31
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
32
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
33
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
34
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
35
+ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
36
+ MODEL_MAPPING_NAMES,
37
+ )
38
+
39
+ from .auto_factory import _BaseAutoModelClass
40
+
41
+
42
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
43
+ {
44
+ "midm": "MidmLMHeadModel",
45
+ "exaone": "ExaoneForCausalLM",
46
+ }
47
+ )
48
+
49
+
50
+ class RBLNAutoModel(_BaseAutoModelClass):
51
+ _model_mapping = MODEL_MAPPING_NAMES
52
+
53
+
54
+ class RBLNAutoModelForCTC(_BaseAutoModelClass):
55
+ _model_mapping = MODEL_FOR_CTC_MAPPING_NAMES
56
+
57
+
58
+ class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
59
+ _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
60
+
61
+
62
+ class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
63
+ _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
64
+
65
+
66
+ class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
67
+ _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
68
+
69
+
70
+ class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
71
+ _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
72
+
73
+
74
+ class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
75
+ _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
76
+
77
+
78
+ class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
79
+ _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
80
+
81
+
82
+ class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
83
+ _model_mapping = MODEL_FOR_MASKED_LM_MAPPING_NAMES
84
+
85
+
86
+ class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
87
+ _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
88
+
89
+
90
+ class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
91
+ _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
92
+
93
+
94
+ class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
95
+ _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
@@ -22,3 +22,4 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  from .bart_architecture import BartDecoderWrapper, BartEncoderWrapper
25
+ from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel