optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +156 -36
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +85 -75
- optimum/rbln/transformers/__init__.py +79 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +96 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
CHANGED
@@ -18,18 +18,13 @@ import shutil
|
|
18
18
|
from abc import ABC, abstractmethod
|
19
19
|
from pathlib import Path
|
20
20
|
from tempfile import TemporaryDirectory
|
21
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
21
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
22
22
|
|
23
23
|
import rebel
|
24
24
|
import torch
|
25
|
-
from transformers import
|
26
|
-
|
27
|
-
|
28
|
-
GenerationConfig,
|
29
|
-
PretrainedConfig,
|
30
|
-
)
|
31
|
-
|
32
|
-
from .modeling_config import RBLNCompileConfig, RBLNConfig, use_rbln_config
|
25
|
+
from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
|
26
|
+
|
27
|
+
from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig
|
33
28
|
from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
|
34
29
|
from .utils.logging import get_logger
|
35
30
|
from .utils.runtime_utils import UnavailableRuntime
|
@@ -47,6 +42,10 @@ class PreTrainedModel(ABC): # noqa: F811
|
|
47
42
|
pass
|
48
43
|
|
49
44
|
|
45
|
+
class RBLNBaseModelConfig(RBLNModelConfig):
|
46
|
+
pass
|
47
|
+
|
48
|
+
|
50
49
|
class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
51
50
|
"""
|
52
51
|
An abstract base class for compiling, loading, and saving neural network models from the huggingface
|
@@ -85,15 +84,17 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
85
84
|
model_type = "rbln_model"
|
86
85
|
auto_model_class = AutoModel
|
87
86
|
config_class = AutoConfig
|
87
|
+
|
88
88
|
config_name = "config.json"
|
89
89
|
hf_library_name = "transformers"
|
90
90
|
_hf_class = None
|
91
|
+
_rbln_config_class = None
|
91
92
|
|
92
93
|
def __init__(
|
93
94
|
self,
|
94
95
|
models: List[rebel.Runtime],
|
95
96
|
config: "PretrainedConfig",
|
96
|
-
rbln_config:
|
97
|
+
rbln_config: RBLNModelConfig,
|
97
98
|
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
98
99
|
subfolder: str = "",
|
99
100
|
rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
|
@@ -103,6 +104,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
103
104
|
self.model = models
|
104
105
|
self.config = config
|
105
106
|
self.rbln_config = rbln_config
|
107
|
+
if not rbln_config.is_frozen():
|
108
|
+
raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
|
109
|
+
|
106
110
|
self.compiled_models = rbln_compiled_models
|
107
111
|
|
108
112
|
# Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
|
@@ -118,7 +122,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
118
122
|
else:
|
119
123
|
self.generation_config = None
|
120
124
|
|
121
|
-
# self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
122
125
|
if self.generation_config is not None:
|
123
126
|
self.generation_config.use_cache = True
|
124
127
|
|
@@ -181,11 +184,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
181
184
|
return rbln_compiled_models
|
182
185
|
|
183
186
|
@classmethod
|
184
|
-
@use_rbln_config
|
185
187
|
def _from_pretrained(
|
186
188
|
cls,
|
187
189
|
model_id: Union[str, Path],
|
188
|
-
config: "PretrainedConfig" = None,
|
190
|
+
config: Optional["PretrainedConfig"] = None,
|
189
191
|
use_auth_token: Optional[Union[bool, str]] = None,
|
190
192
|
revision: Optional[str] = None,
|
191
193
|
force_download: bool = False,
|
@@ -195,17 +197,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
195
197
|
trust_remote_code: bool = False,
|
196
198
|
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
197
199
|
# passed from compile function
|
198
|
-
rbln_config: Optional[
|
200
|
+
rbln_config: Optional[RBLNModelConfig] = None,
|
199
201
|
rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
|
200
202
|
rbln_submodules: List["RBLNBaseModel"] = [],
|
201
203
|
**kwargs,
|
202
204
|
) -> "RBLNBaseModel":
|
203
|
-
|
204
|
-
|
205
|
-
if not from_export_method:
|
206
|
-
# from compiled dir
|
207
|
-
rbln_kwargs = rbln_config or {}
|
208
|
-
|
205
|
+
if rbln_compiled_models is None:
|
209
206
|
model_path_subfolder = cls._load_compiled_model_dir(
|
210
207
|
model_id=model_id,
|
211
208
|
use_auth_token=use_auth_token,
|
@@ -216,16 +213,33 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
216
213
|
local_files_only=local_files_only,
|
217
214
|
)
|
218
215
|
|
219
|
-
rbln_config
|
220
|
-
|
216
|
+
if isinstance(rbln_config, dict):
|
217
|
+
rbln_config_as_kwargs = {f"rbln_{key}": value for key, value in rbln_config.items()}
|
218
|
+
kwargs.update(rbln_config_as_kwargs)
|
219
|
+
elif isinstance(rbln_config, RBLNModelConfig) and rbln_config.rbln_model_cls_name != cls.__name__:
|
220
|
+
raise ValueError(
|
221
|
+
f"Cannot use the passed rbln_config. Its model class name ({rbln_config.rbln_model_cls_name}) "
|
222
|
+
f"does not match the expected model class name ({cls.__name__})."
|
223
|
+
)
|
224
|
+
|
225
|
+
rbln_config, kwargs = RBLNAutoConfig.load(
|
226
|
+
model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
|
227
|
+
)
|
221
228
|
|
222
|
-
if rbln_config.
|
229
|
+
if rbln_config.rbln_model_cls_name != cls.__name__:
|
223
230
|
raise NameError(
|
224
231
|
f"Cannot load the model. The model was originally compiled using "
|
225
|
-
f"{rbln_config.
|
232
|
+
f"{rbln_config.rbln_model_cls_name}, but you are trying to load it with {cls.__name__}."
|
226
233
|
"Please use the same model class that was used during compilation."
|
227
234
|
)
|
228
235
|
|
236
|
+
if len(cls._rbln_submodules) > 0:
|
237
|
+
rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
|
238
|
+
else:
|
239
|
+
rbln_submodules = []
|
240
|
+
|
241
|
+
rbln_config.freeze()
|
242
|
+
|
229
243
|
if config is None:
|
230
244
|
if cls.hf_library_name == "transformers":
|
231
245
|
config = AutoConfig.from_pretrained(
|
@@ -258,15 +272,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
258
272
|
|
259
273
|
rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
|
260
274
|
|
261
|
-
if len(cls._rbln_submodules) > 0:
|
262
|
-
rbln_submodules = cls._load_submodules(
|
263
|
-
model_save_dir=model_id,
|
264
|
-
rbln_kwargs=rbln_kwargs,
|
265
|
-
**kwargs,
|
266
|
-
)
|
267
|
-
else:
|
268
|
-
rbln_submodules = []
|
269
|
-
|
270
275
|
if subfolder != "":
|
271
276
|
model_save_dir = Path(model_path_subfolder).absolute().parent
|
272
277
|
else:
|
@@ -286,7 +291,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
286
291
|
def _from_compiled_models(
|
287
292
|
cls,
|
288
293
|
rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
|
289
|
-
rbln_config:
|
294
|
+
rbln_config: RBLNModelConfig,
|
290
295
|
config: "PretrainedConfig",
|
291
296
|
model_save_dir: Union[Path, str],
|
292
297
|
subfolder: Union[Path, str],
|
@@ -303,7 +308,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
303
308
|
# create runtimes only if `rbln_create_runtimes` is enabled
|
304
309
|
try:
|
305
310
|
models = (
|
306
|
-
cls._create_runtimes(rbln_compiled_models, rbln_config
|
311
|
+
cls._create_runtimes(rbln_compiled_models, rbln_config)
|
307
312
|
if rbln_config.create_runtimes
|
308
313
|
else UnavailableRuntime()
|
309
314
|
)
|
@@ -326,38 +331,31 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
326
331
|
)
|
327
332
|
|
328
333
|
@classmethod
|
329
|
-
|
330
|
-
def _export(
|
331
|
-
cls,
|
332
|
-
model_id: Union[str, Path],
|
333
|
-
rbln_config: Optional[Dict[str, Any]] = None,
|
334
|
-
**kwargs,
|
335
|
-
) -> "RBLNBaseModel":
|
334
|
+
def _export(cls, model_id: Union[str, Path], **kwargs) -> "RBLNBaseModel":
|
336
335
|
subfolder = kwargs.get("subfolder", "")
|
337
336
|
model_save_dir = kwargs.pop("model_save_dir", None)
|
338
337
|
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
rbln_kwargs=rbln_kwargs,
|
343
|
-
**kwargs,
|
344
|
-
)
|
338
|
+
rbln_config, kwargs = cls.prepare_rbln_config(**kwargs)
|
339
|
+
|
340
|
+
model: "PreTrainedModel" = cls.get_pytorch_model(model_id=model_id, rbln_config=rbln_config, **kwargs)
|
345
341
|
preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
|
346
342
|
return cls.from_model(
|
347
|
-
model,
|
348
|
-
rbln_config=rbln_config,
|
349
|
-
preprocessors=preprocessors,
|
350
|
-
model_save_dir=model_save_dir,
|
351
|
-
**kwargs,
|
343
|
+
model, preprocessors=preprocessors, model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
|
352
344
|
)
|
353
345
|
|
354
346
|
@classmethod
|
355
|
-
def
|
356
|
-
cls,
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
347
|
+
def prepare_rbln_config(
|
348
|
+
cls, rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None, **kwargs
|
349
|
+
) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
|
350
|
+
"""
|
351
|
+
Extract rbln-config from kwargs and convert it to RBLNModelConfig.
|
352
|
+
"""
|
353
|
+
config_cls = cls.get_rbln_config_class()
|
354
|
+
rbln_config, kwargs = config_cls.initialize_from_kwargs(rbln_config, **kwargs)
|
355
|
+
return rbln_config, kwargs
|
356
|
+
|
357
|
+
@classmethod
|
358
|
+
def from_pretrained(cls, model_id: Union[str, Path], export: bool = False, **kwargs) -> "RBLNBaseModel":
|
361
359
|
if isinstance(model_id, Path):
|
362
360
|
model_id = model_id.as_posix()
|
363
361
|
from_pretrained_method = cls._export if export else cls._from_pretrained
|
@@ -376,17 +374,14 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
376
374
|
return compiled_model
|
377
375
|
|
378
376
|
@classmethod
|
379
|
-
def
|
380
|
-
cls
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
Note that batch_size should be specified with proper input_info.
|
388
|
-
"""
|
389
|
-
rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
|
377
|
+
def update_rbln_config(cls, **others) -> RBLNModelConfig:
|
378
|
+
rbln_config = cls._update_rbln_config(**others)
|
379
|
+
rbln_config.freeze()
|
380
|
+
if rbln_config.rbln_model_cls_name != cls.__name__:
|
381
|
+
raise NameError(
|
382
|
+
f"Cannot get the rbln config. {cls.__name__} is not the same as {rbln_config.rbln_model_cls_name}. "
|
383
|
+
"This is an internal error. Please report it to the developers."
|
384
|
+
)
|
390
385
|
return rbln_config
|
391
386
|
|
392
387
|
@classmethod
|
@@ -406,6 +401,22 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
406
401
|
cls._hf_class = getattr(library, hf_cls_name, None)
|
407
402
|
return cls._hf_class
|
408
403
|
|
404
|
+
@classmethod
|
405
|
+
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
|
406
|
+
"""
|
407
|
+
Lazily loads and caches the corresponding RBLN model config class.
|
408
|
+
"""
|
409
|
+
if cls._rbln_config_class is None:
|
410
|
+
rbln_config_class_name = cls.__name__ + "Config"
|
411
|
+
library = importlib.import_module("optimum.rbln")
|
412
|
+
cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
|
413
|
+
if cls._rbln_config_class is None:
|
414
|
+
raise ValueError(
|
415
|
+
f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
|
416
|
+
"Please report it to the developers."
|
417
|
+
)
|
418
|
+
return cls._rbln_config_class
|
419
|
+
|
409
420
|
def can_generate(self):
|
410
421
|
return False
|
411
422
|
|
@@ -516,7 +527,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
516
527
|
|
517
528
|
@classmethod
|
518
529
|
@abstractmethod
|
519
|
-
def
|
530
|
+
def _update_rbln_config(cls, **rbln_config_kwargs) -> RBLNModelConfig:
|
520
531
|
pass
|
521
532
|
|
522
533
|
@classmethod
|
@@ -524,8 +535,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
524
535
|
def _create_runtimes(
|
525
536
|
cls,
|
526
537
|
compiled_models: List[rebel.RBLNCompiledModel],
|
527
|
-
|
528
|
-
activate_profiler: Optional[bool] = None,
|
538
|
+
rbln_config: RBLNModelConfig,
|
529
539
|
) -> List[rebel.Runtime]:
|
530
540
|
# compiled_models -> runtimes
|
531
541
|
pass
|
@@ -537,11 +547,11 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
537
547
|
|
538
548
|
@classmethod
|
539
549
|
@abstractmethod
|
540
|
-
@use_rbln_config
|
541
550
|
def from_model(
|
542
551
|
cls,
|
543
552
|
model: "PreTrainedModel",
|
544
|
-
|
553
|
+
config: Optional[PretrainedConfig] = None,
|
554
|
+
rbln_config: Optional[RBLNModelConfig] = None,
|
545
555
|
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
546
556
|
subfolder: str = "",
|
547
557
|
**kwargs,
|
@@ -18,7 +18,15 @@ from transformers.utils import _LazyModule
|
|
18
18
|
|
19
19
|
|
20
20
|
_import_structure = {
|
21
|
-
"
|
21
|
+
"configuration_alias": [
|
22
|
+
"RBLNASTForAudioClassificationConfig",
|
23
|
+
"RBLNDistilBertForQuestionAnsweringConfig",
|
24
|
+
"RBLNResNetForImageClassificationConfig",
|
25
|
+
"RBLNXLMRobertaForSequenceClassificationConfig",
|
26
|
+
"RBLNRobertaForSequenceClassificationConfig",
|
27
|
+
"RBLNRobertaForMaskedLMConfig",
|
28
|
+
"RBLNViTForImageClassificationConfig",
|
29
|
+
],
|
22
30
|
"models": [
|
23
31
|
"RBLNAutoModel",
|
24
32
|
"RBLNAutoModelForAudioClassification",
|
@@ -33,30 +41,58 @@ _import_structure = {
|
|
33
41
|
"RBLNAutoModelForSpeechSeq2Seq",
|
34
42
|
"RBLNAutoModelForVision2Seq",
|
35
43
|
"RBLNBartForConditionalGeneration",
|
44
|
+
"RBLNBartForConditionalGenerationConfig",
|
36
45
|
"RBLNBartModel",
|
37
|
-
"
|
46
|
+
"RBLNBartModelConfig",
|
38
47
|
"RBLNBertForMaskedLM",
|
48
|
+
"RBLNBertForMaskedLMConfig",
|
39
49
|
"RBLNBertForQuestionAnswering",
|
50
|
+
"RBLNBertForQuestionAnsweringConfig",
|
51
|
+
"RBLNBertModel",
|
52
|
+
"RBLNBertModelConfig",
|
40
53
|
"RBLNCLIPTextModel",
|
54
|
+
"RBLNCLIPTextModelConfig",
|
41
55
|
"RBLNCLIPTextModelWithProjection",
|
56
|
+
"RBLNCLIPTextModelWithProjectionConfig",
|
42
57
|
"RBLNCLIPVisionModel",
|
58
|
+
"RBLNCLIPVisionModelConfig",
|
43
59
|
"RBLNCLIPVisionModelWithProjection",
|
60
|
+
"RBLNCLIPVisionModelWithProjectionConfig",
|
61
|
+
"RBLNDecoderOnlyModelForCausalLM",
|
62
|
+
"RBLNDecoderOnlyModelForCausalLMConfig",
|
44
63
|
"RBLNDPTForDepthEstimation",
|
64
|
+
"RBLNDPTForDepthEstimationConfig",
|
45
65
|
"RBLNExaoneForCausalLM",
|
66
|
+
"RBLNExaoneForCausalLMConfig",
|
46
67
|
"RBLNGemmaForCausalLM",
|
68
|
+
"RBLNGemmaForCausalLMConfig",
|
47
69
|
"RBLNGPT2LMHeadModel",
|
48
|
-
"
|
49
|
-
"RBLNWav2Vec2ForCTC",
|
50
|
-
"RBLNWhisperForConditionalGeneration",
|
70
|
+
"RBLNGPT2LMHeadModelConfig",
|
51
71
|
"RBLNLlamaForCausalLM",
|
72
|
+
"RBLNLlamaForCausalLMConfig",
|
73
|
+
"RBLNLlavaNextForConditionalGeneration",
|
74
|
+
"RBLNLlavaNextForConditionalGenerationConfig",
|
75
|
+
"RBLNMidmLMHeadModel",
|
76
|
+
"RBLNMidmLMHeadModelConfig",
|
77
|
+
"RBLNMistralForCausalLM",
|
78
|
+
"RBLNMistralForCausalLMConfig",
|
52
79
|
"RBLNPhiForCausalLM",
|
80
|
+
"RBLNPhiForCausalLMConfig",
|
81
|
+
"RBLNQwen2ForCausalLM",
|
82
|
+
"RBLNQwen2ForCausalLMConfig",
|
53
83
|
"RBLNT5EncoderModel",
|
84
|
+
"RBLNT5EncoderModelConfig",
|
54
85
|
"RBLNT5ForConditionalGeneration",
|
86
|
+
"RBLNT5ForConditionalGenerationConfig",
|
87
|
+
"RBLNWav2Vec2ForCTC",
|
88
|
+
"RBLNWav2Vec2ForCTCConfig",
|
89
|
+
"RBLNWhisperForConditionalGeneration",
|
90
|
+
"RBLNWhisperForConditionalGenerationConfig",
|
55
91
|
"RBLNTimeSeriesTransformerForPrediction",
|
92
|
+
"RBLNTimeSeriesTransformerForPredictionConfig",
|
56
93
|
"RBLNLlavaNextForConditionalGeneration",
|
57
|
-
"RBLNMidmLMHeadModel",
|
58
94
|
"RBLNXLMRobertaModel",
|
59
|
-
"
|
95
|
+
"RBLNXLMRobertaModelConfig",
|
60
96
|
],
|
61
97
|
"modeling_alias": [
|
62
98
|
"RBLNASTForAudioClassification",
|
@@ -70,7 +106,15 @@ _import_structure = {
|
|
70
106
|
}
|
71
107
|
|
72
108
|
if TYPE_CHECKING:
|
73
|
-
from .
|
109
|
+
from .configuration_alias import (
|
110
|
+
RBLNASTForAudioClassificationConfig,
|
111
|
+
RBLNDistilBertForQuestionAnsweringConfig,
|
112
|
+
RBLNResNetForImageClassificationConfig,
|
113
|
+
RBLNRobertaForMaskedLMConfig,
|
114
|
+
RBLNRobertaForSequenceClassificationConfig,
|
115
|
+
RBLNViTForImageClassificationConfig,
|
116
|
+
RBLNXLMRobertaForSequenceClassificationConfig,
|
117
|
+
)
|
74
118
|
from .modeling_alias import (
|
75
119
|
RBLNASTForAudioClassification,
|
76
120
|
RBLNDistilBertForQuestionAnswering,
|
@@ -94,30 +138,57 @@ if TYPE_CHECKING:
|
|
94
138
|
RBLNAutoModelForSpeechSeq2Seq,
|
95
139
|
RBLNAutoModelForVision2Seq,
|
96
140
|
RBLNBartForConditionalGeneration,
|
141
|
+
RBLNBartForConditionalGenerationConfig,
|
97
142
|
RBLNBartModel,
|
143
|
+
RBLNBartModelConfig,
|
98
144
|
RBLNBertForMaskedLM,
|
145
|
+
RBLNBertForMaskedLMConfig,
|
99
146
|
RBLNBertForQuestionAnswering,
|
147
|
+
RBLNBertForQuestionAnsweringConfig,
|
100
148
|
RBLNBertModel,
|
149
|
+
RBLNBertModelConfig,
|
101
150
|
RBLNCLIPTextModel,
|
151
|
+
RBLNCLIPTextModelConfig,
|
102
152
|
RBLNCLIPTextModelWithProjection,
|
153
|
+
RBLNCLIPTextModelWithProjectionConfig,
|
103
154
|
RBLNCLIPVisionModel,
|
155
|
+
RBLNCLIPVisionModelConfig,
|
104
156
|
RBLNCLIPVisionModelWithProjection,
|
157
|
+
RBLNCLIPVisionModelWithProjectionConfig,
|
158
|
+
RBLNDecoderOnlyModelForCausalLM,
|
159
|
+
RBLNDecoderOnlyModelForCausalLMConfig,
|
105
160
|
RBLNDPTForDepthEstimation,
|
161
|
+
RBLNDPTForDepthEstimationConfig,
|
106
162
|
RBLNExaoneForCausalLM,
|
163
|
+
RBLNExaoneForCausalLMConfig,
|
107
164
|
RBLNGemmaForCausalLM,
|
165
|
+
RBLNGemmaForCausalLMConfig,
|
108
166
|
RBLNGPT2LMHeadModel,
|
167
|
+
RBLNGPT2LMHeadModelConfig,
|
109
168
|
RBLNLlamaForCausalLM,
|
169
|
+
RBLNLlamaForCausalLMConfig,
|
110
170
|
RBLNLlavaNextForConditionalGeneration,
|
171
|
+
RBLNLlavaNextForConditionalGenerationConfig,
|
111
172
|
RBLNMidmLMHeadModel,
|
173
|
+
RBLNMidmLMHeadModelConfig,
|
112
174
|
RBLNMistralForCausalLM,
|
175
|
+
RBLNMistralForCausalLMConfig,
|
113
176
|
RBLNPhiForCausalLM,
|
177
|
+
RBLNPhiForCausalLMConfig,
|
114
178
|
RBLNQwen2ForCausalLM,
|
179
|
+
RBLNQwen2ForCausalLMConfig,
|
115
180
|
RBLNT5EncoderModel,
|
181
|
+
RBLNT5EncoderModelConfig,
|
116
182
|
RBLNT5ForConditionalGeneration,
|
183
|
+
RBLNT5ForConditionalGenerationConfig,
|
117
184
|
RBLNTimeSeriesTransformerForPrediction,
|
185
|
+
RBLNTimeSeriesTransformerForPredictionConfig,
|
118
186
|
RBLNWav2Vec2ForCTC,
|
187
|
+
RBLNWav2Vec2ForCTCConfig,
|
119
188
|
RBLNWhisperForConditionalGeneration,
|
189
|
+
RBLNWhisperForConditionalGenerationConfig,
|
120
190
|
RBLNXLMRobertaModel,
|
191
|
+
RBLNXLMRobertaModelConfig,
|
121
192
|
)
|
122
193
|
else:
|
123
194
|
import sys
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from .configuration_generic import (
|
16
|
+
RBLNModelForAudioClassificationConfig,
|
17
|
+
RBLNModelForImageClassificationConfig,
|
18
|
+
RBLNModelForMaskedLMConfig,
|
19
|
+
RBLNModelForQuestionAnsweringConfig,
|
20
|
+
RBLNModelForSequenceClassificationConfig,
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
class RBLNASTForAudioClassificationConfig(RBLNModelForAudioClassificationConfig):
|
25
|
+
pass
|
26
|
+
|
27
|
+
|
28
|
+
class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
|
29
|
+
pass
|
30
|
+
|
31
|
+
|
32
|
+
class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
|
33
|
+
pass
|
34
|
+
|
35
|
+
|
36
|
+
class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
|
37
|
+
pass
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
|
41
|
+
pass
|
42
|
+
|
43
|
+
|
44
|
+
class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
|
45
|
+
pass
|
46
|
+
|
47
|
+
|
48
|
+
class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
|
49
|
+
pass
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import List, Optional, Tuple, Union
|
16
|
+
|
17
|
+
from ..configuration_utils import RBLNModelConfig
|
18
|
+
|
19
|
+
|
20
|
+
class _RBLNTransformerEncoderConfig(RBLNModelConfig):
|
21
|
+
rbln_model_input_names: Optional[List[str]] = None
|
22
|
+
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
max_seq_len: Optional[int] = None,
|
26
|
+
batch_size: Optional[int] = None,
|
27
|
+
model_input_names: Optional[List[str]] = None,
|
28
|
+
**kwargs,
|
29
|
+
):
|
30
|
+
"""
|
31
|
+
Args:
|
32
|
+
max_seq_len (Optional[int]): Maximum sequence length supported by the model.
|
33
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
34
|
+
model_input_names (Optional[List[str]]): Names of the input tensors for the model.
|
35
|
+
Defaults to class-specific rbln_model_input_names if not provided.
|
36
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
37
|
+
|
38
|
+
Raises:
|
39
|
+
ValueError: If batch_size is not a positive integer.
|
40
|
+
"""
|
41
|
+
super().__init__(**kwargs)
|
42
|
+
self.max_seq_len = max_seq_len
|
43
|
+
self.batch_size = batch_size or 1
|
44
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
45
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
46
|
+
|
47
|
+
self.model_input_names = model_input_names or self.rbln_model_input_names
|
48
|
+
|
49
|
+
|
50
|
+
class _RBLNImageModelConfig(RBLNModelConfig):
|
51
|
+
def __init__(
|
52
|
+
self, image_size: Optional[Union[int, Tuple[int, int]]] = None, batch_size: Optional[int] = None, **kwargs
|
53
|
+
):
|
54
|
+
"""
|
55
|
+
Args:
|
56
|
+
image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
|
57
|
+
Can be an integer for square images or a tuple (height, width).
|
58
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
59
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
60
|
+
|
61
|
+
Raises:
|
62
|
+
ValueError: If batch_size is not a positive integer.
|
63
|
+
"""
|
64
|
+
super().__init__(**kwargs)
|
65
|
+
self.image_size = image_size
|
66
|
+
self.batch_size = batch_size or 1
|
67
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
68
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
69
|
+
|
70
|
+
@property
|
71
|
+
def image_width(self):
|
72
|
+
if isinstance(self.image_size, int):
|
73
|
+
return self.image_size
|
74
|
+
elif isinstance(self.image_size, (list, tuple)):
|
75
|
+
return self.image_size[1]
|
76
|
+
else:
|
77
|
+
return self.image_size["width"]
|
78
|
+
|
79
|
+
@property
|
80
|
+
def image_height(self):
|
81
|
+
if isinstance(self.image_size, int):
|
82
|
+
return self.image_size
|
83
|
+
elif isinstance(self.image_size, (list, tuple)):
|
84
|
+
return self.image_size[0]
|
85
|
+
else:
|
86
|
+
return self.image_size["height"]
|
87
|
+
|
88
|
+
|
89
|
+
class RBLNModelForQuestionAnsweringConfig(_RBLNTransformerEncoderConfig):
|
90
|
+
pass
|
91
|
+
|
92
|
+
|
93
|
+
class RBLNModelForSequenceClassificationConfig(_RBLNTransformerEncoderConfig):
|
94
|
+
pass
|
95
|
+
|
96
|
+
|
97
|
+
class RBLNModelForMaskedLMConfig(_RBLNTransformerEncoderConfig):
|
98
|
+
pass
|
99
|
+
|
100
|
+
|
101
|
+
class RBLNModelForTextEncodingConfig(_RBLNTransformerEncoderConfig):
|
102
|
+
pass
|
103
|
+
|
104
|
+
|
105
|
+
# FIXME : Appropriate name ?
|
106
|
+
class RBLNTransformerEncoderForFeatureExtractionConfig(_RBLNTransformerEncoderConfig):
|
107
|
+
pass
|
108
|
+
|
109
|
+
|
110
|
+
class RBLNModelForImageClassificationConfig(_RBLNImageModelConfig):
|
111
|
+
pass
|
112
|
+
|
113
|
+
|
114
|
+
class RBLNModelForDepthEstimationConfig(_RBLNImageModelConfig):
|
115
|
+
pass
|
116
|
+
|
117
|
+
|
118
|
+
class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
|
119
|
+
def __init__(
|
120
|
+
self,
|
121
|
+
batch_size: Optional[int] = None,
|
122
|
+
max_length: Optional[int] = None,
|
123
|
+
num_mel_bins: Optional[int] = None,
|
124
|
+
**kwargs,
|
125
|
+
):
|
126
|
+
"""
|
127
|
+
Args:
|
128
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
129
|
+
max_length (Optional[int]): Maximum length of the audio input in time dimension.
|
130
|
+
num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
|
131
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
132
|
+
|
133
|
+
Raises:
|
134
|
+
ValueError: If batch_size is not a positive integer.
|
135
|
+
"""
|
136
|
+
super().__init__(**kwargs)
|
137
|
+
self.batch_size = batch_size or 1
|
138
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
139
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
140
|
+
|
141
|
+
self.max_length = max_length
|
142
|
+
self.num_mel_bins = num_mel_bins
|