optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
@@ -18,18 +18,13 @@ import shutil
18
18
  from abc import ABC, abstractmethod
19
19
  from pathlib import Path
20
20
  from tempfile import TemporaryDirectory
21
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
21
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
22
22
 
23
23
  import rebel
24
24
  import torch
25
- from transformers import (
26
- AutoConfig,
27
- AutoModel,
28
- GenerationConfig,
29
- PretrainedConfig,
30
- )
31
-
32
- from .modeling_config import RBLNCompileConfig, RBLNConfig, use_rbln_config
25
+ from transformers import AutoConfig, AutoModel, GenerationConfig, PretrainedConfig
26
+
27
+ from .configuration_utils import RBLNAutoConfig, RBLNCompileConfig, RBLNModelConfig
33
28
  from .utils.hub import PushToHubMixin, pull_compiled_model_from_hub, validate_files
34
29
  from .utils.logging import get_logger
35
30
  from .utils.runtime_utils import UnavailableRuntime
@@ -47,6 +42,10 @@ class PreTrainedModel(ABC): # noqa: F811
47
42
  pass
48
43
 
49
44
 
45
class RBLNBaseModelConfig(RBLNModelConfig):
    """Empty subclass of `RBLNModelConfig`.

    Its name follows the `<model class name> + "Config"` pattern, so it is the
    name that matches `RBLNBaseModel`. It defines no options of its own.
    """
47
+
48
+
50
49
  class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
51
50
  """
52
51
  An abstract base class for compiling, loading, and saving neural network models from the huggingface
@@ -85,15 +84,17 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
85
84
  model_type = "rbln_model"
86
85
  auto_model_class = AutoModel
87
86
  config_class = AutoConfig
87
+
88
88
  config_name = "config.json"
89
89
  hf_library_name = "transformers"
90
90
  _hf_class = None
91
+ _rbln_config_class = None
91
92
 
92
93
  def __init__(
93
94
  self,
94
95
  models: List[rebel.Runtime],
95
96
  config: "PretrainedConfig",
96
- rbln_config: RBLNConfig,
97
+ rbln_config: RBLNModelConfig,
97
98
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
98
99
  subfolder: str = "",
99
100
  rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
@@ -103,6 +104,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
103
104
  self.model = models
104
105
  self.config = config
105
106
  self.rbln_config = rbln_config
107
+ if not rbln_config.is_frozen():
108
+ raise RuntimeError("`rbln_config` must be frozen. Please call `rbln_config.freeze()` first.")
109
+
106
110
  self.compiled_models = rbln_compiled_models
107
111
 
108
112
  # Registers the RBLN classes into the transformers AutoModel classes to avoid warnings when creating
@@ -118,7 +122,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
118
122
  else:
119
123
  self.generation_config = None
120
124
 
121
- # self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
122
125
  if self.generation_config is not None:
123
126
  self.generation_config.use_cache = True
124
127
 
@@ -181,11 +184,10 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
181
184
  return rbln_compiled_models
182
185
 
183
186
  @classmethod
184
- @use_rbln_config
185
187
  def _from_pretrained(
186
188
  cls,
187
189
  model_id: Union[str, Path],
188
- config: "PretrainedConfig" = None,
190
+ config: Optional["PretrainedConfig"] = None,
189
191
  use_auth_token: Optional[Union[bool, str]] = None,
190
192
  revision: Optional[str] = None,
191
193
  force_download: bool = False,
@@ -195,17 +197,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
195
197
  trust_remote_code: bool = False,
196
198
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
197
199
  # passed from compile function
198
- rbln_config: Optional[RBLNConfig] = None,
200
+ rbln_config: Optional[RBLNModelConfig] = None,
199
201
  rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
200
202
  rbln_submodules: List["RBLNBaseModel"] = [],
201
203
  **kwargs,
202
204
  ) -> "RBLNBaseModel":
203
- from_export_method = isinstance(rbln_config, RBLNConfig) and rbln_compiled_models is not None
204
-
205
- if not from_export_method:
206
- # from compiled dir
207
- rbln_kwargs = rbln_config or {}
208
-
205
+ if rbln_compiled_models is None:
209
206
  model_path_subfolder = cls._load_compiled_model_dir(
210
207
  model_id=model_id,
211
208
  use_auth_token=use_auth_token,
@@ -216,16 +213,33 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
216
213
  local_files_only=local_files_only,
217
214
  )
218
215
 
219
- rbln_config = RBLNConfig.load(model_path_subfolder)
220
- rbln_config.update_runtime_cfg(rbln_kwargs)
216
+ if isinstance(rbln_config, dict):
217
+ rbln_config_as_kwargs = {f"rbln_{key}": value for key, value in rbln_config.items()}
218
+ kwargs.update(rbln_config_as_kwargs)
219
+ elif isinstance(rbln_config, RBLNModelConfig) and rbln_config.rbln_model_cls_name != cls.__name__:
220
+ raise ValueError(
221
+ f"Cannot use the passed rbln_config. Its model class name ({rbln_config.rbln_model_cls_name}) "
222
+ f"does not match the expected model class name ({cls.__name__})."
223
+ )
224
+
225
+ rbln_config, kwargs = RBLNAutoConfig.load(
226
+ model_path_subfolder, passed_rbln_config=rbln_config, kwargs=kwargs, return_unused_kwargs=True
227
+ )
221
228
 
222
- if rbln_config.meta["cls"] != cls.__name__:
229
+ if rbln_config.rbln_model_cls_name != cls.__name__:
223
230
  raise NameError(
224
231
  f"Cannot load the model. The model was originally compiled using "
225
- f"{rbln_config.meta['cls']}, but you are trying to load it with {cls.__name__}."
232
+ f"{rbln_config.rbln_model_cls_name}, but you are trying to load it with {cls.__name__}."
226
233
  "Please use the same model class that was used during compilation."
227
234
  )
228
235
 
236
+ if len(cls._rbln_submodules) > 0:
237
+ rbln_submodules = cls._load_submodules(model_save_dir=model_id, rbln_config=rbln_config, **kwargs)
238
+ else:
239
+ rbln_submodules = []
240
+
241
+ rbln_config.freeze()
242
+
229
243
  if config is None:
230
244
  if cls.hf_library_name == "transformers":
231
245
  config = AutoConfig.from_pretrained(
@@ -258,15 +272,6 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
258
272
 
259
273
  rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
260
274
 
261
- if len(cls._rbln_submodules) > 0:
262
- rbln_submodules = cls._load_submodules(
263
- model_save_dir=model_id,
264
- rbln_kwargs=rbln_kwargs,
265
- **kwargs,
266
- )
267
- else:
268
- rbln_submodules = []
269
-
270
275
  if subfolder != "":
271
276
  model_save_dir = Path(model_path_subfolder).absolute().parent
272
277
  else:
@@ -286,7 +291,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
286
291
  def _from_compiled_models(
287
292
  cls,
288
293
  rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
289
- rbln_config: RBLNConfig,
294
+ rbln_config: RBLNModelConfig,
290
295
  config: "PretrainedConfig",
291
296
  model_save_dir: Union[Path, str],
292
297
  subfolder: Union[Path, str],
@@ -303,7 +308,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
303
308
  # create runtimes only if `rbln_create_runtimes` is enabled
304
309
  try:
305
310
  models = (
306
- cls._create_runtimes(rbln_compiled_models, rbln_config.device_map, rbln_config.activate_profiler)
311
+ cls._create_runtimes(rbln_compiled_models, rbln_config)
307
312
  if rbln_config.create_runtimes
308
313
  else UnavailableRuntime()
309
314
  )
@@ -326,38 +331,31 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
326
331
  )
327
332
 
328
333
  @classmethod
329
- @use_rbln_config
330
- def _export(
331
- cls,
332
- model_id: Union[str, Path],
333
- rbln_config: Optional[Dict[str, Any]] = None,
334
- **kwargs,
335
- ) -> "RBLNBaseModel":
334
+ def _export(cls, model_id: Union[str, Path], **kwargs) -> "RBLNBaseModel":
336
335
  subfolder = kwargs.get("subfolder", "")
337
336
  model_save_dir = kwargs.pop("model_save_dir", None)
338
337
 
339
- rbln_kwargs = rbln_config
340
- model: "PreTrainedModel" = cls.get_pytorch_model(
341
- model_id=model_id,
342
- rbln_kwargs=rbln_kwargs,
343
- **kwargs,
344
- )
338
+ rbln_config, kwargs = cls.prepare_rbln_config(**kwargs)
339
+
340
+ model: "PreTrainedModel" = cls.get_pytorch_model(model_id=model_id, rbln_config=rbln_config, **kwargs)
345
341
  preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
346
342
  return cls.from_model(
347
- model,
348
- rbln_config=rbln_config,
349
- preprocessors=preprocessors,
350
- model_save_dir=model_save_dir,
351
- **kwargs,
343
+ model, preprocessors=preprocessors, model_save_dir=model_save_dir, rbln_config=rbln_config, **kwargs
352
344
  )
353
345
 
354
346
  @classmethod
355
- def from_pretrained(
356
- cls,
357
- model_id: Union[str, Path],
358
- export: bool = False,
359
- **kwargs,
360
- ) -> "RBLNBaseModel":
347
def prepare_rbln_config(
    cls,
    rbln_config: Optional[Union[Dict[str, Any], RBLNModelConfig]] = None,
    **kwargs,
) -> Tuple[RBLNModelConfig, Dict[str, Any]]:
    """
    Turn the caller's `rbln_config` argument and keyword arguments into a config object.

    Args:
        rbln_config: A config object, a plain dict, or None.
        **kwargs: Keyword arguments forwarded unchanged to `initialize_from_kwargs`
            of the config class returned by `get_rbln_config_class`.

    Returns:
        The `(rbln_config, kwargs)` pair produced by `initialize_from_kwargs`.
        NOTE(review): the second item is presumably the kwargs left over after the
        config options were consumed -- confirm against `RBLNModelConfig`.
    """
    resolved_config, remaining_kwargs = cls.get_rbln_config_class().initialize_from_kwargs(rbln_config, **kwargs)
    return resolved_config, remaining_kwargs
356
+
357
+ @classmethod
358
+ def from_pretrained(cls, model_id: Union[str, Path], export: bool = False, **kwargs) -> "RBLNBaseModel":
361
359
  if isinstance(model_id, Path):
362
360
  model_id = model_id.as_posix()
363
361
  from_pretrained_method = cls._export if export else cls._from_pretrained
@@ -376,17 +374,14 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
376
374
  return compiled_model
377
375
 
378
376
  @classmethod
379
- def get_rbln_config(
380
- cls,
381
- rbln_kwargs: Dict[str, Any],
382
- **others,
383
- ) -> RBLNConfig:
384
- """
385
- Make default rbln-config for the model.
386
- kwargs for overriding model's config can be accepted.
387
- Note that batch_size should be specified with proper input_info.
388
- """
389
- rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
377
def update_rbln_config(cls, **others) -> RBLNModelConfig:
    """
    Run the subclass hook `_update_rbln_config`, freeze its result and validate its owner.

    Args:
        **others: Keyword arguments passed straight through to `_update_rbln_config`.

    Returns:
        The frozen config returned by `_update_rbln_config`.

    Raises:
        NameError: If the config's `rbln_model_cls_name` differs from this class's name.
    """
    finalized_config = cls._update_rbln_config(**others)
    finalized_config.freeze()

    # Happy path first: the config must have been produced for exactly this model class.
    if finalized_config.rbln_model_cls_name == cls.__name__:
        return finalized_config

    raise NameError(
        f"Cannot get the rbln config. {cls.__name__} is not the same as {finalized_config.rbln_model_cls_name}. "
        "This is an internal error. Please report it to the developers."
    )
391
386
 
392
387
  @classmethod
@@ -406,6 +401,22 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
406
401
  cls._hf_class = getattr(library, hf_cls_name, None)
407
402
  return cls._hf_class
408
403
 
404
+ @classmethod
405
def get_rbln_config_class(cls) -> Type[RBLNModelConfig]:
    """
    Lazily resolve and cache the RBLN model config class that matches this model class.

    The config class is looked up as an attribute of the `optimum.rbln` package under
    the name `<cls.__name__>Config`, and the result is stored on this class.

    Returns:
        The config class registered for this exact model class.

    Raises:
        ValueError: If `optimum.rbln` has no attribute with the expected name.
    """
    # Look only at this class's own namespace. A plain `cls._rbln_config_class` read also
    # sees a value that an earlier call cached on a parent class, which would hand the
    # parent's config class to every subclass resolved after it.
    cached_config_class = cls.__dict__.get("_rbln_config_class")
    if cached_config_class is not None:
        return cached_config_class

    rbln_config_class_name = cls.__name__ + "Config"
    library = importlib.import_module("optimum.rbln")
    cls._rbln_config_class = getattr(library, rbln_config_class_name, None)
    if cls._rbln_config_class is None:
        raise ValueError(
            f"RBLN config class {rbln_config_class_name} not found. This is an internal error. "
            "Please report it to the developers."
        )
    return cls._rbln_config_class
419
+
409
420
  def can_generate(self):
410
421
  return False
411
422
 
@@ -516,7 +527,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
516
527
 
517
528
  @classmethod
518
529
  @abstractmethod
519
- def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
530
+ def _update_rbln_config(cls, **rbln_config_kwargs) -> RBLNModelConfig:
520
531
  pass
521
532
 
522
533
  @classmethod
@@ -524,8 +535,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
524
535
  def _create_runtimes(
525
536
  cls,
526
537
  compiled_models: List[rebel.RBLNCompiledModel],
527
- rbln_device_map: Dict[str, int],
528
- activate_profiler: Optional[bool] = None,
538
+ rbln_config: RBLNModelConfig,
529
539
  ) -> List[rebel.Runtime]:
530
540
  # compiled_models -> runtimes
531
541
  pass
@@ -537,11 +547,11 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
537
547
 
538
548
  @classmethod
539
549
  @abstractmethod
540
- @use_rbln_config
541
550
  def from_model(
542
551
  cls,
543
552
  model: "PreTrainedModel",
544
- rbln_config: Dict[str, Any] = {},
553
+ config: Optional[PretrainedConfig] = None,
554
+ rbln_config: Optional[RBLNModelConfig] = None,
545
555
  model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
546
556
  subfolder: str = "",
547
557
  **kwargs,
@@ -18,7 +18,15 @@ from transformers.utils import _LazyModule
18
18
 
19
19
 
20
20
  _import_structure = {
21
- "cache_utils": ["RebelDynamicCache"],
21
+ "configuration_alias": [
22
+ "RBLNASTForAudioClassificationConfig",
23
+ "RBLNDistilBertForQuestionAnsweringConfig",
24
+ "RBLNResNetForImageClassificationConfig",
25
+ "RBLNXLMRobertaForSequenceClassificationConfig",
26
+ "RBLNRobertaForSequenceClassificationConfig",
27
+ "RBLNRobertaForMaskedLMConfig",
28
+ "RBLNViTForImageClassificationConfig",
29
+ ],
22
30
  "models": [
23
31
  "RBLNAutoModel",
24
32
  "RBLNAutoModelForAudioClassification",
@@ -33,30 +41,58 @@ _import_structure = {
33
41
  "RBLNAutoModelForSpeechSeq2Seq",
34
42
  "RBLNAutoModelForVision2Seq",
35
43
  "RBLNBartForConditionalGeneration",
44
+ "RBLNBartForConditionalGenerationConfig",
36
45
  "RBLNBartModel",
37
- "RBLNBertModel",
46
+ "RBLNBartModelConfig",
38
47
  "RBLNBertForMaskedLM",
48
+ "RBLNBertForMaskedLMConfig",
39
49
  "RBLNBertForQuestionAnswering",
50
+ "RBLNBertForQuestionAnsweringConfig",
51
+ "RBLNBertModel",
52
+ "RBLNBertModelConfig",
40
53
  "RBLNCLIPTextModel",
54
+ "RBLNCLIPTextModelConfig",
41
55
  "RBLNCLIPTextModelWithProjection",
56
+ "RBLNCLIPTextModelWithProjectionConfig",
42
57
  "RBLNCLIPVisionModel",
58
+ "RBLNCLIPVisionModelConfig",
43
59
  "RBLNCLIPVisionModelWithProjection",
60
+ "RBLNCLIPVisionModelWithProjectionConfig",
61
+ "RBLNDecoderOnlyModelForCausalLM",
62
+ "RBLNDecoderOnlyModelForCausalLMConfig",
44
63
  "RBLNDPTForDepthEstimation",
64
+ "RBLNDPTForDepthEstimationConfig",
45
65
  "RBLNExaoneForCausalLM",
66
+ "RBLNExaoneForCausalLMConfig",
46
67
  "RBLNGemmaForCausalLM",
68
+ "RBLNGemmaForCausalLMConfig",
47
69
  "RBLNGPT2LMHeadModel",
48
- "RBLNQwen2ForCausalLM",
49
- "RBLNWav2Vec2ForCTC",
50
- "RBLNWhisperForConditionalGeneration",
70
+ "RBLNGPT2LMHeadModelConfig",
51
71
  "RBLNLlamaForCausalLM",
72
+ "RBLNLlamaForCausalLMConfig",
73
+ "RBLNLlavaNextForConditionalGeneration",
74
+ "RBLNLlavaNextForConditionalGenerationConfig",
75
+ "RBLNMidmLMHeadModel",
76
+ "RBLNMidmLMHeadModelConfig",
77
+ "RBLNMistralForCausalLM",
78
+ "RBLNMistralForCausalLMConfig",
52
79
  "RBLNPhiForCausalLM",
80
+ "RBLNPhiForCausalLMConfig",
81
+ "RBLNQwen2ForCausalLM",
82
+ "RBLNQwen2ForCausalLMConfig",
53
83
  "RBLNT5EncoderModel",
84
+ "RBLNT5EncoderModelConfig",
54
85
  "RBLNT5ForConditionalGeneration",
86
+ "RBLNT5ForConditionalGenerationConfig",
87
+ "RBLNWav2Vec2ForCTC",
88
+ "RBLNWav2Vec2ForCTCConfig",
89
+ "RBLNWhisperForConditionalGeneration",
90
+ "RBLNWhisperForConditionalGenerationConfig",
55
91
  "RBLNTimeSeriesTransformerForPrediction",
92
+ "RBLNTimeSeriesTransformerForPredictionConfig",
56
93
  "RBLNLlavaNextForConditionalGeneration",
57
- "RBLNMidmLMHeadModel",
58
94
  "RBLNXLMRobertaModel",
59
- "RBLNMistralForCausalLM",
95
+ "RBLNXLMRobertaModelConfig",
60
96
  ],
61
97
  "modeling_alias": [
62
98
  "RBLNASTForAudioClassification",
@@ -70,7 +106,15 @@ _import_structure = {
70
106
  }
71
107
 
72
108
  if TYPE_CHECKING:
73
- from .cache_utils import RebelDynamicCache
109
+ from .configuration_alias import (
110
+ RBLNASTForAudioClassificationConfig,
111
+ RBLNDistilBertForQuestionAnsweringConfig,
112
+ RBLNResNetForImageClassificationConfig,
113
+ RBLNRobertaForMaskedLMConfig,
114
+ RBLNRobertaForSequenceClassificationConfig,
115
+ RBLNViTForImageClassificationConfig,
116
+ RBLNXLMRobertaForSequenceClassificationConfig,
117
+ )
74
118
  from .modeling_alias import (
75
119
  RBLNASTForAudioClassification,
76
120
  RBLNDistilBertForQuestionAnswering,
@@ -94,30 +138,57 @@ if TYPE_CHECKING:
94
138
  RBLNAutoModelForSpeechSeq2Seq,
95
139
  RBLNAutoModelForVision2Seq,
96
140
  RBLNBartForConditionalGeneration,
141
+ RBLNBartForConditionalGenerationConfig,
97
142
  RBLNBartModel,
143
+ RBLNBartModelConfig,
98
144
  RBLNBertForMaskedLM,
145
+ RBLNBertForMaskedLMConfig,
99
146
  RBLNBertForQuestionAnswering,
147
+ RBLNBertForQuestionAnsweringConfig,
100
148
  RBLNBertModel,
149
+ RBLNBertModelConfig,
101
150
  RBLNCLIPTextModel,
151
+ RBLNCLIPTextModelConfig,
102
152
  RBLNCLIPTextModelWithProjection,
153
+ RBLNCLIPTextModelWithProjectionConfig,
103
154
  RBLNCLIPVisionModel,
155
+ RBLNCLIPVisionModelConfig,
104
156
  RBLNCLIPVisionModelWithProjection,
157
+ RBLNCLIPVisionModelWithProjectionConfig,
158
+ RBLNDecoderOnlyModelForCausalLM,
159
+ RBLNDecoderOnlyModelForCausalLMConfig,
105
160
  RBLNDPTForDepthEstimation,
161
+ RBLNDPTForDepthEstimationConfig,
106
162
  RBLNExaoneForCausalLM,
163
+ RBLNExaoneForCausalLMConfig,
107
164
  RBLNGemmaForCausalLM,
165
+ RBLNGemmaForCausalLMConfig,
108
166
  RBLNGPT2LMHeadModel,
167
+ RBLNGPT2LMHeadModelConfig,
109
168
  RBLNLlamaForCausalLM,
169
+ RBLNLlamaForCausalLMConfig,
110
170
  RBLNLlavaNextForConditionalGeneration,
171
+ RBLNLlavaNextForConditionalGenerationConfig,
111
172
  RBLNMidmLMHeadModel,
173
+ RBLNMidmLMHeadModelConfig,
112
174
  RBLNMistralForCausalLM,
175
+ RBLNMistralForCausalLMConfig,
113
176
  RBLNPhiForCausalLM,
177
+ RBLNPhiForCausalLMConfig,
114
178
  RBLNQwen2ForCausalLM,
179
+ RBLNQwen2ForCausalLMConfig,
115
180
  RBLNT5EncoderModel,
181
+ RBLNT5EncoderModelConfig,
116
182
  RBLNT5ForConditionalGeneration,
183
+ RBLNT5ForConditionalGenerationConfig,
117
184
  RBLNTimeSeriesTransformerForPrediction,
185
+ RBLNTimeSeriesTransformerForPredictionConfig,
118
186
  RBLNWav2Vec2ForCTC,
187
+ RBLNWav2Vec2ForCTCConfig,
119
188
  RBLNWhisperForConditionalGeneration,
189
+ RBLNWhisperForConditionalGenerationConfig,
120
190
  RBLNXLMRobertaModel,
191
+ RBLNXLMRobertaModelConfig,
121
192
  )
122
193
  else:
123
194
  import sys
@@ -0,0 +1,49 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_generic import (
16
+ RBLNModelForAudioClassificationConfig,
17
+ RBLNModelForImageClassificationConfig,
18
+ RBLNModelForMaskedLMConfig,
19
+ RBLNModelForQuestionAnsweringConfig,
20
+ RBLNModelForSequenceClassificationConfig,
21
+ )
22
+
23
+
24
class RBLNASTForAudioClassificationConfig(RBLNModelForAudioClassificationConfig):
    """Config named for `RBLNASTForAudioClassification`; adds nothing to `RBLNModelForAudioClassificationConfig`."""
26
+
27
+
28
class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
    """Config named for `RBLNDistilBertForQuestionAnswering`; adds nothing to `RBLNModelForQuestionAnsweringConfig`."""
30
+
31
+
32
class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
    """Config named for `RBLNResNetForImageClassification`; adds nothing to `RBLNModelForImageClassificationConfig`."""
34
+
35
+
36
class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
    """Config named for `RBLNXLMRobertaForSequenceClassification`; adds nothing to its parent config class."""
38
+
39
+
40
class RBLNRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
    """Config named for `RBLNRobertaForSequenceClassification`; adds nothing to its parent config class."""
42
+
43
+
44
class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
    """Config named for `RBLNRobertaForMaskedLM`; adds nothing to `RBLNModelForMaskedLMConfig`."""
46
+
47
+
48
class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
    """Config named for `RBLNViTForImageClassification`; adds nothing to `RBLNModelForImageClassificationConfig`."""
@@ -0,0 +1,142 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ from ..configuration_utils import RBLNModelConfig
18
+
19
+
20
class _RBLNTransformerEncoderConfig(RBLNModelConfig):
    """Shared config base holding sequence length, batch size and input names."""

    # Class-level fallback used when the caller passes no `model_input_names`.
    rbln_model_input_names: Optional[List[str]] = None

    def __init__(
        self,
        max_seq_len: Optional[int] = None,
        batch_size: Optional[int] = None,
        model_input_names: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Args:
            max_seq_len (Optional[int]): Maximum sequence length supported by the model.
            batch_size (Optional[int]): The batch size for inference. Any falsy value
                (None or 0) is replaced by 1.
            model_input_names (Optional[List[str]]): Names of the input tensors for the model.
                Falls back to the class attribute `rbln_model_input_names` when falsy.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved batch_size is not an int or is negative.
        """
        super().__init__(**kwargs)
        self.max_seq_len = max_seq_len

        # Store first, validate second: the error message reports the stored value.
        self.batch_size = batch_size if batch_size else 1
        batch_size_ok = isinstance(self.batch_size, int) and self.batch_size >= 0
        if not batch_size_ok:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        if model_input_names:
            self.model_input_names = model_input_names
        else:
            self.model_input_names = self.rbln_model_input_names
48
+
49
+
50
class _RBLNImageModelConfig(RBLNModelConfig):
    """Shared config base holding image size and batch size."""

    def __init__(
        self,
        image_size: Optional[Union[int, Tuple[int, int]]] = None,
        batch_size: Optional[int] = None,
        **kwargs,
    ):
        """
        Args:
            image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
                An integer means a square image; a list or tuple is read as (height, width).
                Any other value is indexed with the keys "height" and "width".
            batch_size (Optional[int]): The batch size for inference. Any falsy value
                (None or 0) is replaced by 1.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved batch_size is not an int or is negative.
        """
        super().__init__(**kwargs)
        self.image_size = image_size

        # Store first, validate second: the error message reports the stored value.
        self.batch_size = batch_size if batch_size else 1
        batch_size_ok = isinstance(self.batch_size, int) and self.batch_size >= 0
        if not batch_size_ok:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

    @property
    def image_width(self):
        """Width taken from `image_size`: the int itself, element 1, or the "width" key."""
        size = self.image_size
        if isinstance(size, int):
            return size
        if isinstance(size, (list, tuple)):
            return size[1]
        return size["width"]

    @property
    def image_height(self):
        """Height taken from `image_size`: the int itself, element 0, or the "height" key."""
        size = self.image_size
        if isinstance(size, int):
            return size
        if isinstance(size, (list, tuple)):
            return size[0]
        return size["height"]
87
+
88
+
89
class RBLNModelForQuestionAnsweringConfig(_RBLNTransformerEncoderConfig):
    """Question-answering config; every option is inherited from `_RBLNTransformerEncoderConfig`."""
91
+
92
+
93
class RBLNModelForSequenceClassificationConfig(_RBLNTransformerEncoderConfig):
    """Sequence-classification config; every option is inherited from `_RBLNTransformerEncoderConfig`."""
95
+
96
+
97
class RBLNModelForMaskedLMConfig(_RBLNTransformerEncoderConfig):
    """Masked-LM config; every option is inherited from `_RBLNTransformerEncoderConfig`."""
99
+
100
+
101
class RBLNModelForTextEncodingConfig(_RBLNTransformerEncoderConfig):
    """Text-encoding config; every option is inherited from `_RBLNTransformerEncoderConfig`."""
103
+
104
+
105
+ # FIXME : Appropriate name ?
106
class RBLNTransformerEncoderForFeatureExtractionConfig(_RBLNTransformerEncoderConfig):
    """Feature-extraction config; every option is inherited from `_RBLNTransformerEncoderConfig`."""
108
+
109
+
110
class RBLNModelForImageClassificationConfig(_RBLNImageModelConfig):
    """Image-classification config; every option is inherited from `_RBLNImageModelConfig`."""
112
+
113
+
114
class RBLNModelForDepthEstimationConfig(_RBLNImageModelConfig):
    """Depth-estimation config; every option is inherited from `_RBLNImageModelConfig`."""
116
+
117
+
118
class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
    """Config for audio-classification models: batch size, input length and Mel-bin count."""

    def __init__(
        self,
        batch_size: Optional[int] = None,
        max_length: Optional[int] = None,
        num_mel_bins: Optional[int] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Any falsy value
                (None or 0) is replaced by 1.
            max_length (Optional[int]): Maximum length of the audio input in time dimension.
            num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved batch_size is not an int or is negative.
        """
        super().__init__(**kwargs)

        # Store first, validate second: the error message reports the stored value.
        self.batch_size = batch_size if batch_size else 1
        batch_size_ok = isinstance(self.batch_size, int) and self.batch_size >= 0
        if not batch_size_ok:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.max_length = max_length
        self.num_mel_bins = num_mel_bins