optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -20,14 +20,15 @@
20
20
  # are the intellectual property of Rebellions Inc. and may not be
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
+ import copy
23
24
  import importlib
24
25
  from os import PathLike
25
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
26
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
26
27
 
27
28
  import torch
28
29
 
29
- from .modeling_base import RBLNModel
30
- from .modeling_config import ContextRblnConfig, use_rbln_config
30
+ from .modeling import RBLNModel
31
+ from .modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
31
32
  from .utils.decorator_utils import remove_compile_time_kwargs
32
33
 
33
34
 
@@ -74,127 +75,40 @@ class RBLNDiffusionMixin:
74
75
 
75
76
  @classmethod
76
77
  @property
77
- def use_encode(cls):
78
+ def img2img_pipeline(cls):
78
79
  return "Img2Img" in cls.__name__
79
80
 
80
81
  @classmethod
81
- def _get_unet_batch_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> int:
82
- # Calculates the batch size based on guidance scale
83
- batch_size = rbln_config.get("batch_size", 1)
84
- do_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
85
- return batch_size * 2 if do_guidance else batch_size
86
-
87
- @classmethod
88
- def _get_vae_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
89
- image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
90
- if (image_size[0] is None) != (image_size[1] is None):
91
- raise ValueError("Both image height and image width must be given or not given")
92
- elif image_size[0] is None and image_size[1] is None:
93
- if cls.use_encode:
94
- sample_size = model.vae.config.sample_size
95
- else:
96
- # In case of text2img, sample size of vae decoder is determined by unet.
97
- unet_sample_size = model.unet.config.sample_size
98
- if isinstance(unet_sample_size, int):
99
- sample_size = unet_sample_size * model.vae_scale_factor
100
- else:
101
- sample_size = (
102
- unet_sample_size[0] * model.vae_scale_factor,
103
- unet_sample_size[1] * model.vae_scale_factor,
104
- )
105
-
106
- else:
107
- sample_size = (image_size[0], image_size[1])
108
- return sample_size
109
-
110
- @classmethod
111
- def _get_unet_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
112
- image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
113
- if (image_size[0] is None) != (image_size[1] is None):
114
- raise ValueError("Both image height and image width must be given or not given")
115
- elif image_size[0] is None and image_size[1] is None:
116
- if cls.use_encode:
117
- # In case of img2img, sample size of unet is determined by vae encoder.
118
- vae_sample_size = model.vae.config.sample_size
119
- if isinstance(vae_sample_size, int):
120
- sample_size = vae_sample_size // model.vae_scale_factor
121
- else:
122
- sample_size = (
123
- vae_sample_size[0] // model.vae_scale_factor,
124
- vae_sample_size[1] // model.vae_scale_factor,
125
- )
126
- else:
127
- sample_size = model.unet.config.sample_size
128
- else:
129
- sample_size = (image_size[0] // model.vae_scale_factor, image_size[1] // model.vae_scale_factor)
130
- return sample_size
131
-
132
- @classmethod
133
- def _get_default_config(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
134
- # default configurations for each submodules
135
- return {"img2img_pipeline": cls.use_encode}
82
+ @property
83
+ def inpaint_pipeline(cls):
84
+ return "Inpaint" in cls.__name__
136
85
 
137
86
  @classmethod
138
- def get_default_rbln_config_text_encoder(
139
- cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
87
+ def get_submodule_rbln_config(
88
+ cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
140
89
  ) -> Dict[str, Any]:
141
- batch_size = rbln_config.get("batch_size", 1)
142
- return {"batch_size": batch_size}
90
+ submodule = getattr(model, submodule_name)
91
+ submodule_class_name = submodule.__class__.__name__
143
92
 
144
- @classmethod
145
- def get_default_rbln_config_text_encoder_2(
146
- cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
147
- ) -> Dict[str, Any]:
148
- batch_size = rbln_config.get("batch_size", 1)
149
- return {"batch_size": batch_size}
93
+ if submodule_class_name == "MultiControlNetModel":
94
+ submodule_class_name = "ControlNetModel"
150
95
 
151
- @classmethod
152
- def get_default_rbln_config_unet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
153
- # configuration for unet
154
- unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
155
- text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
156
- return {
157
- **cls._get_default_config(model, rbln_config),
158
- "max_seq_len": model.text_encoder.config.max_position_embeddings,
159
- "text_model_hidden_size": text_model_hidden_size,
160
- "batch_size": unet_batch_size,
161
- "sample_size": cls._get_unet_sample_size(model, rbln_config),
162
- "is_controlnet": "controlnet" in model.config.keys(),
163
- }
96
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")
164
97
 
165
- @classmethod
166
- def get_default_rbln_config_vae(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
167
- # configuration for vae
168
- batch_size = rbln_config.get("batch_size", 1)
169
- return {
170
- **cls._get_default_config(model, rbln_config),
171
- "sample_size": cls._get_vae_sample_size(model, rbln_config),
172
- "batch_size": batch_size,
173
- }
98
+ submodule_config = rbln_config.get(submodule_name, {})
99
+ submodule_config = copy.deepcopy(submodule_config)
174
100
 
175
- @classmethod
176
- def get_default_rbln_config_controlnet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
177
- # configuration for controlnet
178
- unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
179
- text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
180
- return {
181
- **cls._get_default_config(model, rbln_config),
182
- "max_seq_len": model.text_encoder.config.max_position_embeddings,
183
- "vae_sample_size": cls._get_vae_sample_size(model, rbln_config),
184
- "unet_sample_size": cls._get_unet_sample_size(model, rbln_config),
185
- "batch_size": unet_batch_size,
186
- "text_model_hidden_size": text_model_hidden_size,
187
- }
101
+ pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
188
102
 
189
- @classmethod
190
- def get_default_rbln_config(
191
- cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
192
- ) -> Dict[str, Any]:
193
- # Returns the default configuration based on submodule name
194
- config_method = f"get_default_rbln_config_{submodule_name}"
195
- if hasattr(cls, config_method):
196
- return getattr(cls, config_method)(model, rbln_config)
197
- raise ValueError(f"Unknown submodule: {submodule_name}")
103
+ submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
104
+ submodule_config.update(
105
+ {
106
+ "img2img_pipeline": cls.img2img_pipeline,
107
+ "inpaint_pipeline": cls.inpaint_pipeline,
108
+ }
109
+ )
110
+ submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
111
+ return submodule_config
198
112
 
199
113
  @staticmethod
200
114
  def _maybe_apply_and_fuse_lora(
@@ -256,12 +170,32 @@ class RBLNDiffusionMixin:
256
170
 
257
171
  else:
258
172
  # raise error if any of submodules are torch module.
259
- for name in cls._submodules:
260
- if isinstance(kwargs.get(name), torch.nn.Module):
173
+ model_index_config = None
174
+ for submodule_name in cls._submodules:
175
+ if isinstance(kwargs.get(submodule_name), torch.nn.Module):
261
176
  raise AssertionError(
262
- f"{name} is not compiled torch module. If you want to compile, set `export=True`."
177
+ f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
263
178
  )
264
179
 
180
+ # Load submodule outside if runtime kwargs(e.g. device) is specified.
181
+ if submodule_config := rbln_config.get(submodule_name):
182
+ if any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
183
+ if model_index_config is None:
184
+ model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)
185
+
186
+ module_name, class_name = model_index_config[submodule_name]
187
+ if module_name != "optimum.rbln":
188
+ raise ValueError(
189
+ f"Invalid module_name '{module_name}' found in model_index.json for "
190
+ f"submodule '{submodule_name}'. "
191
+ "Expected 'optimum.rbln'. Please check the model_index.json configuration."
192
+ )
193
+ submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
194
+ submodule = submodule_cls.from_pretrained(
195
+ model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
196
+ )
197
+ kwargs[submodule_name] = submodule
198
+
265
199
  with ContextRblnConfig(
266
200
  device=rbln_config.get("device"),
267
201
  device_map=rbln_config.get("device_map"),
@@ -291,16 +225,11 @@ class RBLNDiffusionMixin:
291
225
  model_save_dir: Optional[PathLike],
292
226
  rbln_config: Dict[str, Any],
293
227
  ) -> Dict[str, RBLNModel]:
294
- # Compile submodules based on rbln_config
295
228
  compiled_submodules = {}
296
229
 
297
- # FIXME : Currently, optimum-rbln for transformer does not use base rbln config.
298
- base_rbln_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
299
230
  for submodule_name in cls._submodules:
300
231
  submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
301
- submodule_rbln_config = cls.get_default_rbln_config(model, submodule_name, rbln_config)
302
- submodule_rbln_config.update(base_rbln_config)
303
- submodule_rbln_config.update(rbln_config.get(submodule_name, {}))
232
+ submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)
304
233
 
305
234
  if submodule is None:
306
235
  raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
@@ -28,7 +28,6 @@ from transformers.utils import _LazyModule
28
28
 
29
29
  _import_structure = {
30
30
  "cache_utils": ["RebelDynamicCache"],
31
- "generation": ["BatchTextIteratorStreamer"],
32
31
  "models": [
33
32
  "RBLNAutoModel",
34
33
  "RBLNAutoModelForAudioClassification",
@@ -68,7 +67,6 @@ _import_structure = {
68
67
 
69
68
  if TYPE_CHECKING:
70
69
  from .cache_utils import RebelDynamicCache
71
- from .generation import BatchTextIteratorStreamer
72
70
  from .models import (
73
71
  RBLNAutoModel,
74
72
  RBLNAutoModelForAudioClassification,
@@ -22,8 +22,16 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import importlib
25
+ import inspect
26
+ import warnings
25
27
 
26
- from transformers import AutoConfig
28
+ from transformers import AutoConfig, PretrainedConfig
29
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
30
+ from transformers.models.auto.auto_factory import _get_model_class
31
+
32
+ from optimum.rbln.modeling_base import RBLNBaseModel
33
+ from optimum.rbln.modeling_config import RBLNConfig
34
+ from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
27
35
 
28
36
 
29
37
  class _BaseAutoModelClass:
@@ -33,46 +41,132 @@ class _BaseAutoModelClass:
33
41
  def __init__(self, *args, **kwargs):
34
42
  raise EnvironmentError(
35
43
  f"{self.__class__.__name__} is designed to be instantiated "
36
- f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
37
- f"`{self.__class__.__name__}.from_config(config)` methods."
44
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`"
38
45
  )
39
46
 
40
47
  @classmethod
41
48
  def get_rbln_cls(
42
49
  cls,
43
- model_id,
50
+ pretrained_model_name_or_path,
44
51
  *args,
52
+ export=True,
45
53
  **kwargs,
46
54
  ):
47
- # kwargs.update({"return_unused_kwargs": True})
48
- config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, **kwargs)[0]
49
-
50
- if len(config.architectures) > 1:
51
- raise ValueError(
52
- f"Model with ID '{model_id}' has multiple architectures defined in the configuration: "
53
- f"{config.architectures}. `_BaseAutoModelClass` require exactly one architecture. "
54
- )
55
-
56
- architecture_name = config.architectures[0]
57
- if architecture_name not in cls._model_mapping.values():
58
- raise ValueError(
59
- f"The 'RBLN{architecture_name}' architecture is not supported by `{cls.__name__}.from_pretrained()`."
60
- "Please use the appropriate class's `from_pretrained()` method to load this model."
61
- )
62
-
63
- rbln_class_name = "RBLN" + architecture_name
64
- module = importlib.import_module("optimum.rbln")
55
+ """
56
+ Determine the appropriate RBLN model class based on the given model ID and configuration.
57
+
58
+ Args:
59
+ pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
60
+ export (bool): Whether to infer the class based on Hugging Face (HF) architecture.
61
+ kwargs: Additional arguments for configuration and loading.
62
+
63
+ Returns:
64
+ RBLNBaseModel: The corresponding RBLN model class.
65
+ """
66
+ if export:
67
+ hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
68
+ rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
69
+ else:
70
+ rbln_class_name = cls.get_rbln_model_class_name(pretrained_model_name_or_path, **kwargs)
71
+
72
+ if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
73
+ raise ValueError(
74
+ f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
75
+ "Please use the `from_pretrained()` method of the appropriate class to load this model, "
76
+ f"or directly use '{rbln_class_name}.from_pretrained()`."
77
+ )
65
78
 
66
79
  try:
80
+ module = importlib.import_module("optimum.rbln")
67
81
  rbln_cls = getattr(module, rbln_class_name)
68
82
  except AttributeError as e:
69
83
  raise AttributeError(
70
- f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{model_id}'. "
84
+ f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
71
85
  "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
72
86
  ) from e
73
87
 
74
88
  return rbln_cls
75
89
 
90
+ @classmethod
91
+ def infer_hf_model_class(
92
+ cls,
93
+ pretrained_model_name_or_path,
94
+ *args,
95
+ **kwargs,
96
+ ):
97
+ """
98
+ Infer the Hugging Face model class based on the configuration or model name.
99
+
100
+ Args:
101
+ pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
102
+ kwargs: Additional arguments for configuration and loading.
103
+
104
+ Returns:
105
+ PretrainedModel: The inferred Hugging Face model class.
106
+ """
107
+
108
+ # Try to load configuration if provided or retrieve it from the model ID
109
+ config = kwargs.pop("config", None)
110
+ kwargs.update({"trust_remote_code": True})
111
+ kwargs["_from_auto"] = True
112
+
113
+ # Load configuration if not already provided
114
+ if not isinstance(config, PretrainedConfig):
115
+ config, kwargs = AutoConfig.from_pretrained(
116
+ pretrained_model_name_or_path,
117
+ return_unused_kwargs=True,
118
+ **kwargs,
119
+ )
120
+
121
+ # Get hf_model_class from Config
122
+ has_remote_code = (
123
+ hasattr(config, "auto_map") and convert_rbln_to_hf_model_name(cls.__name__) in config.auto_map
124
+ )
125
+ if has_remote_code:
126
+ class_ref = config.auto_map[convert_rbln_to_hf_model_name(cls.__name__)]
127
+ model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
128
+ elif type(config) in cls._model_mapping.keys():
129
+ model_class = _get_model_class(config, cls._model_mapping)
130
+ else:
131
+ raise ValueError(
132
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
133
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
134
+ )
135
+
136
+ if model_class.__name__ != config.architectures[0]:
137
+ warnings.warn(
138
+ f"`{cls.__name__}.from_pretrained()` is invoking `{convert_hf_to_rbln_model_name(model_class.__name__)}.from_pretrained()`, which does not match the "
139
+ f"expected architecture `RBLN{config.architectures[0]}` from config. This mismatch could cause some operations to not be properly loaded "
140
+ f"from the checkpoint, leading to potential unintended behavior. If this is not intentional, consider calling the "
141
+ f"`from_pretrained()` method directly from the `RBLN{config.architectures[0]}` class instead.",
142
+ UserWarning,
143
+ )
144
+
145
+ return model_class
146
+
147
+ @classmethod
148
+ def get_rbln_model_class_name(cls, pretrained_model_name_or_path, **kwargs):
149
+ """
150
+ Retrieve the path to the compiled model directory for a given RBLN model.
151
+
152
+ Args:
153
+ pretrained_model_name_or_path (str): Identifier of the model.
154
+ kwargs: Additional arguments that match the parameters of `_load_compiled_model_dir`.
155
+
156
+ Returns:
157
+ str: Path to the compiled model directory.
158
+ """
159
+ sig = inspect.signature(RBLNBaseModel._load_compiled_model_dir)
160
+ valid_params = sig.parameters.keys()
161
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
162
+
163
+ model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
164
+ model_id=pretrained_model_name_or_path, **filtered_kwargs
165
+ )
166
+ rbln_config = RBLNConfig.load(model_path_subfolder)
167
+
168
+ return rbln_config.meta["cls"]
169
+
76
170
  @classmethod
77
171
  def from_pretrained(
78
172
  cls,
@@ -21,18 +21,31 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
+
24
25
  from transformers.models.auto.modeling_auto import (
26
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
25
27
  MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
28
+ MODEL_FOR_CAUSAL_LM_MAPPING,
26
29
  MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
30
+ MODEL_FOR_CTC_MAPPING,
27
31
  MODEL_FOR_CTC_MAPPING_NAMES,
32
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
28
33
  MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
34
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
29
35
  MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
36
+ MODEL_FOR_MASKED_LM_MAPPING,
30
37
  MODEL_FOR_MASKED_LM_MAPPING_NAMES,
38
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
31
39
  MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
40
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
32
41
  MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
42
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
33
43
  MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
44
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
34
45
  MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
46
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
35
47
  MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
48
+ MODEL_MAPPING,
36
49
  MODEL_MAPPING_NAMES,
37
50
  )
38
51
 
@@ -48,48 +61,60 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
48
61
 
49
62
 
50
63
  class RBLNAutoModel(_BaseAutoModelClass):
51
- _model_mapping = MODEL_MAPPING_NAMES
64
+ _model_mapping = MODEL_MAPPING
65
+ _model_mapping_names = MODEL_MAPPING_NAMES
52
66
 
53
67
 
54
68
  class RBLNAutoModelForCTC(_BaseAutoModelClass):
55
- _model_mapping = MODEL_FOR_CTC_MAPPING_NAMES
69
+ _model_mapping = MODEL_FOR_CTC_MAPPING
70
+ _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
56
71
 
57
72
 
58
73
  class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
59
- _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
74
+ _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
75
+ _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
60
76
 
61
77
 
62
78
  class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
63
- _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
79
+ _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
80
+ _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
64
81
 
65
82
 
66
83
  class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
67
- _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
84
+ _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
85
+ _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
68
86
 
69
87
 
70
88
  class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
71
- _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
89
+ _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
90
+ _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
72
91
 
73
92
 
74
93
  class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
75
- _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
94
+ _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
95
+ _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
76
96
 
77
97
 
78
98
  class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
79
- _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
99
+ _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
100
+ _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
80
101
 
81
102
 
82
103
  class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
83
- _model_mapping = MODEL_FOR_MASKED_LM_MAPPING_NAMES
104
+ _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
105
+ _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
84
106
 
85
107
 
86
108
  class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
87
- _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
109
+ _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
110
+ _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
88
111
 
89
112
 
90
113
  class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
91
- _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
114
+ _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
115
+ _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
92
116
 
93
117
 
94
118
  class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
95
- _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
119
+ _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
120
+ _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
@@ -24,9 +24,9 @@
24
24
  import inspect
25
25
  from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
26
26
 
27
- from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
27
+ from transformers import BartForConditionalGeneration, PretrainedConfig
28
28
 
29
- from ....modeling_base import RBLNModel
29
+ from ....modeling import RBLNModel
30
30
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
31
31
  from ....utils.logging import get_logger
32
32
  from ...models.seq2seq import RBLNModelForSeq2SeqLM
@@ -41,9 +41,6 @@ if TYPE_CHECKING:
41
41
 
42
42
 
43
43
  class RBLNBartModel(RBLNModel):
44
- original_model_class = BartModel
45
- original_config_class = BartConfig
46
-
47
44
  @classmethod
48
45
  def _get_rbln_config(
49
46
  cls,
@@ -82,7 +79,7 @@ class RBLNBartModel(RBLNModel):
82
79
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
83
80
  rbln_model_input_names = cls.rbln_model_input_names
84
81
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
85
- input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
82
+ input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
86
83
  raise ValueError(
87
84
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
88
85
  f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"
@@ -25,9 +25,9 @@ import inspect
25
25
  import logging
26
26
  from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
27
 
28
- from transformers import BertConfig, BertModel, PretrainedConfig
28
+ from transformers import PretrainedConfig
29
29
 
30
- from ....modeling_base import RBLNModel
30
+ from ....modeling import RBLNModel
31
31
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
32
32
 
33
33
 
@@ -38,9 +38,6 @@ if TYPE_CHECKING:
38
38
 
39
39
 
40
40
  class RBLNBertModel(RBLNModel):
41
- original_model_class = BertModel
42
- original_config_class = BertConfig
43
-
44
41
  @classmethod
45
42
  def _get_rbln_config(
46
43
  cls,
@@ -75,7 +72,7 @@ class RBLNBertModel(RBLNModel):
75
72
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
76
73
  rbln_model_input_names = cls.rbln_model_input_names
77
74
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
78
- input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
75
+ input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
79
76
  raise ValueError(
80
77
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
81
78
  f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"
@@ -28,16 +28,15 @@ import torch
28
28
  from transformers import (
29
29
  CLIPTextConfig,
30
30
  CLIPTextModel,
31
- CLIPTextModelWithProjection,
32
31
  CLIPVisionConfig,
33
32
  CLIPVisionModel,
34
33
  )
35
34
  from transformers.modeling_outputs import BaseModelOutputWithPooling
36
35
  from transformers.models.clip.modeling_clip import CLIPTextModelOutput
37
36
 
38
- from ....modeling_base import RBLNModel
37
+ from ....modeling import RBLNModel
39
38
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
40
- from ....utils.context import override_auto_classes
39
+ from ....modeling_diffusers import RBLNDiffusionMixin
41
40
 
42
41
 
43
42
  logger = logging.getLogger(__name__)
@@ -57,20 +56,14 @@ class _TextEncoder(torch.nn.Module):
57
56
 
58
57
 
59
58
  class RBLNCLIPTextModel(RBLNModel):
60
- @classmethod
61
- def from_pretrained(cls, *args, **kwargs):
62
- with override_auto_classes(
63
- config_func=CLIPTextConfig.from_pretrained,
64
- model_func=CLIPTextModel.from_pretrained,
65
- skip_taskmanager=False,
66
- ):
67
- rt = super().from_pretrained(*args, **kwargs)
68
- return rt
69
-
70
59
  @classmethod
71
60
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
72
61
  return _TextEncoder(model).eval()
73
62
 
63
+ @classmethod
64
+ def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
65
+ return rbln_config
66
+
74
67
  @classmethod
75
68
  def _get_rbln_config(
76
69
  cls,
@@ -114,7 +107,7 @@ class RBLNCLIPTextModel(RBLNModel):
114
107
 
115
108
 
116
109
  class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
117
- original_model_class = CLIPTextModelWithProjection
110
+ pass
118
111
 
119
112
 
120
113
  class _VisionEncoder(torch.nn.Module):
@@ -128,16 +121,6 @@ class _VisionEncoder(torch.nn.Module):
128
121
 
129
122
 
130
123
  class RBLNCLIPVisionModel(RBLNModel):
131
- @classmethod
132
- def from_pretrained(cls, *args, **kwargs):
133
- with override_auto_classes(
134
- config_func=CLIPVisionConfig.from_pretrained,
135
- model_func=CLIPVisionModel.from_pretrained,
136
- skip_taskmanager=False,
137
- ):
138
- rt = super().from_pretrained(*args, **kwargs)
139
- return rt
140
-
141
124
  @classmethod
142
125
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
143
126
  return _VisionEncoder(model).eval()
@@ -146,7 +129,7 @@ class RBLNCLIPVisionModel(RBLNModel):
146
129
  def _get_rbln_config(
147
130
  cls,
148
131
  preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
149
- model_config: "CLIPTextConfig",
132
+ model_config: "CLIPVisionConfig",
150
133
  rbln_kwargs: Dict[str, Any] = {},
151
134
  ) -> RBLNConfig:
152
135
  rbln_batch_size = rbln_kwargs.get("batch_size", 1)
@@ -22,9 +22,6 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  from .decoderonly_architecture import (
25
- DecoderOnlyAttention,
26
- DecoderOnlyDecoderLayer,
27
- DecoderOnlyModel,
28
25
  DecoderOnlyWrapper,
29
26
  RotaryEmbedding,
30
27
  apply_rotary_pos_emb,