optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/__init__.py
@@ -21,35 +21,84 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+from typing import TYPE_CHECKING
 
-from .auto import (
-    RBLNAutoModel,
-    RBLNAutoModelForAudioClassification,
-    RBLNAutoModelForCausalLM,
-    RBLNAutoModelForCTC,
-    RBLNAutoModelForDepthEstimation,
-    RBLNAutoModelForImageClassification,
-    RBLNAutoModelForMaskedLM,
-    RBLNAutoModelForQuestionAnswering,
-    RBLNAutoModelForSeq2SeqLM,
-    RBLNAutoModelForSequenceClassification,
-    RBLNAutoModelForSpeechSeq2Seq,
-    RBLNAutoModelForVision2Seq,
-)
-from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
-from .bert import RBLNBertModel
-from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
-from .dpt import RBLNDPTForDepthEstimation
-from .exaone import RBLNExaoneForCausalLM
-from .gemma import RBLNGemmaForCausalLM
-from .gpt2 import RBLNGPT2LMHeadModel
-from .llama import RBLNLlamaForCausalLM
-from .llava_next import RBLNLlavaNextForConditionalGeneration
-from .midm import RBLNMidmLMHeadModel
-from .mistral import RBLNMistralForCausalLM
-from .phi import RBLNPhiForCausalLM
-from .qwen2 import RBLNQwen2ForCausalLM
-from .t5 import RBLNT5ForConditionalGeneration
-from .wav2vec2 import RBLNWav2Vec2ForCTC
-from .whisper import RBLNWhisperForConditionalGeneration
-from .xlm_roberta import RBLNXLMRobertaModel
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "auto": [
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+    ],
+    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
+    "bert": ["RBLNBertModel"],
+    "clip": ["RBLNCLIPTextModel", "RBLNCLIPTextModelWithProjection", "RBLNCLIPVisionModel"],
+    "dpt": ["RBLNDPTForDepthEstimation"],
+    "exaone": ["RBLNExaoneForCausalLM"],
+    "gemma": ["RBLNGemmaForCausalLM"],
+    "gpt2": ["RBLNGPT2LMHeadModel"],
+    "llama": ["RBLNLlamaForCausalLM"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
+    "midm": ["RBLNMidmLMHeadModel"],
+    "mistral": ["RBLNMistralForCausalLM"],
+    "phi": ["RBLNPhiForCausalLM"],
+    "qwen2": ["RBLNQwen2ForCausalLM"],
+    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
+    "whisper": ["RBLNWhisperForConditionalGeneration"],
+    "xlm_roberta": ["RBLNXLMRobertaModel"],
+}
+
+if TYPE_CHECKING:
+    from .auto import (
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+    )
+    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
+    from .bert import RBLNBertModel
+    from .clip import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection, RBLNCLIPVisionModel
+    from .dpt import RBLNDPTForDepthEstimation
+    from .exaone import RBLNExaoneForCausalLM
+    from .gemma import RBLNGemmaForCausalLM
+    from .gpt2 import RBLNGPT2LMHeadModel
+    from .llama import RBLNLlamaForCausalLM
+    from .llava_next import RBLNLlavaNextForConditionalGeneration
+    from .midm import RBLNMidmLMHeadModel
+    from .mistral import RBLNMistralForCausalLM
+    from .phi import RBLNPhiForCausalLM
+    from .qwen2 import RBLNQwen2ForCausalLM
+    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
+    from .wav2vec2 import RBLNWav2Vec2ForCTC
+    from .whisper import RBLNWhisperForConditionalGeneration
+    from .xlm_roberta import RBLNXLMRobertaModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
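
For readers unfamiliar with the pattern above: `_LazyModule` defers every submodule import until an attribute is first accessed, so importing `optimum.rbln` no longer pays for all backends up front. Below is a minimal sketch of the same mechanism using module-level `__getattr__` (PEP 562), intended to live in a package's `__init__.py`; the submodule names here are illustrative, not the package's actual layout.

import importlib
from typing import TYPE_CHECKING

# Illustrative subset of an import structure: submodule -> exported names.
_import_structure = {
    "llama": ["RBLNLlamaForCausalLM"],
    "gpt2": ["RBLNGPT2LMHeadModel"],
}

# Reverse map: exported name -> submodule that defines it.
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}

if TYPE_CHECKING:
    # Static type checkers still see concrete imports.
    from .llama import RBLNLlamaForCausalLM  # noqa: F401
else:
    def __getattr__(name):
        # Import the defining submodule only on first access (PEP 562).
        if name in _name_to_module:
            submodule = importlib.import_module(f".{_name_to_module[name]}", __name__)
            return getattr(submodule, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")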

optimum/rbln/transformers/models/auto/auto_factory.py
@@ -22,8 +22,16 @@
 # from Rebellions Inc.
 
 import importlib
+import inspect
+import warnings
 
-from transformers import AutoConfig
+from transformers import AutoConfig, PretrainedConfig
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
+from transformers.models.auto.auto_factory import _get_model_class
+
+from optimum.rbln.modeling_base import RBLNBaseModel
+from optimum.rbln.modeling_config import RBLNConfig
+from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
 
 
 class _BaseAutoModelClass:
@@ -33,46 +41,132 @@ class _BaseAutoModelClass:
     def __init__(self, *args, **kwargs):
         raise EnvironmentError(
             f"{self.__class__.__name__} is designed to be instantiated "
-            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
-            f"`{self.__class__.__name__}.from_config(config)` methods."
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`"
         )
 
     @classmethod
     def get_rbln_cls(
         cls,
-        model_id,
+        pretrained_model_name_or_path,
         *args,
+        export=True,
         **kwargs,
     ):
-        # kwargs.update({"return_unused_kwargs": True})
-        config = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, **kwargs)[0]
-
-        if len(config.architectures) > 1:
-            raise ValueError(
-                f"Model with ID '{model_id}' has multiple architectures defined in the configuration: "
-                f"{config.architectures}. `_BaseAutoModelClass` require exactly one architecture. "
-            )
-
-        architecture_name = config.architectures[0]
-        if architecture_name not in cls._model_mapping.values():
-            raise ValueError(
-                f"The 'RBLN{architecture_name}' architecture is not supported by `{cls.__name__}.from_pretrained()`."
-                "Please use the appropriate class's `from_pretrained()` method to load this model."
-            )
-
-        rbln_class_name = "RBLN" + architecture_name
-        module = importlib.import_module("optimum.rbln")
+        """
+        Determine the appropriate RBLN model class based on the given model ID and configuration.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
+            export (bool): Whether to infer the class from the Hugging Face (HF) architecture.
+            kwargs: Additional arguments for configuration and loading.
+
+        Returns:
+            RBLNBaseModel: The corresponding RBLN model class.
+        """
+        if export:
+            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
+            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
+        else:
+            rbln_class_name = cls.get_rbln_model_class_name(pretrained_model_name_or_path, **kwargs)
+
+        if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
+            raise ValueError(
+                f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
+                "Please use the `from_pretrained()` method of the appropriate class to load this model, "
+                f"or directly use `{rbln_class_name}.from_pretrained()`."
+            )
 
         try:
+            module = importlib.import_module("optimum.rbln")
             rbln_cls = getattr(module, rbln_class_name)
         except AttributeError as e:
             raise AttributeError(
-                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{model_id}'. "
+                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
                 "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
             ) from e
 
         return rbln_cls
 
+    @classmethod
+    def infer_hf_model_class(
+        cls,
+        pretrained_model_name_or_path,
+        *args,
+        **kwargs,
+    ):
+        """
+        Infer the Hugging Face model class based on the configuration or model name.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
+            kwargs: Additional arguments for configuration and loading.
+
+        Returns:
+            PretrainedModel: The inferred Hugging Face model class.
+        """
+
+        # Try to load the configuration if provided, or retrieve it from the model ID
+        config = kwargs.pop("config", None)
+        kwargs.update({"trust_remote_code": True})
+        kwargs["_from_auto"] = True
+
+        # Load the configuration if not already provided
+        if not isinstance(config, PretrainedConfig):
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path,
+                return_unused_kwargs=True,
+                **kwargs,
+            )
+
+        # Get the HF model class from the config
+        has_remote_code = (
+            hasattr(config, "auto_map") and convert_rbln_to_hf_model_name(cls.__name__) in config.auto_map
+        )
+        if has_remote_code:
+            class_ref = config.auto_map[convert_rbln_to_hf_model_name(cls.__name__)]
+            model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+        elif type(config) in cls._model_mapping.keys():
+            model_class = _get_model_class(config, cls._model_mapping)
+        else:
+            raise ValueError(
+                f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+                f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
+            )
+
+        if model_class.__name__ != config.architectures[0]:
+            warnings.warn(
+                f"`{cls.__name__}.from_pretrained()` is invoking `{convert_hf_to_rbln_model_name(model_class.__name__)}.from_pretrained()`, which does not match the "
+                f"expected architecture `RBLN{config.architectures[0]}` from the config. This mismatch could cause some operations to not be properly loaded "
+                f"from the checkpoint, leading to potential unintended behavior. If this is not intentional, consider calling the "
+                f"`from_pretrained()` method directly from the `RBLN{config.architectures[0]}` class instead.",
+                UserWarning,
+            )
+
+        return model_class
+
+    @classmethod
+    def get_rbln_model_class_name(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Retrieve the RBLN model class name recorded in a compiled model directory.
+
+        Args:
+            pretrained_model_name_or_path (str): Identifier of the model.
+            kwargs: Additional arguments that match the parameters of `_load_compiled_model_dir`.
+
+        Returns:
+            str: Name of the RBLN model class stored in the compiled model's config.
+        """
+        sig = inspect.signature(RBLNBaseModel._load_compiled_model_dir)
+        valid_params = sig.parameters.keys()
+        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+
+        model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
+            model_id=pretrained_model_name_or_path, **filtered_kwargs
+        )
+        rbln_config = RBLNConfig.load(model_path_subfolder)
+
+        return rbln_config.meta["cls"]
+
     @classmethod
     def from_pretrained(
         cls,
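
To make the control flow above concrete, here is a hedged usage sketch. It assumes `from_pretrained` forwards `export` to `get_rbln_cls` the way other optimum backends do (the method body is truncated in this diff), and the model IDs are placeholders.

from optimum.rbln import RBLNAutoModelForCausalLM

# export=True (default): infer the HF architecture from the checkpoint's
# config, map it to the matching RBLN class, and compile it.
compiled = RBLNAutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", export=True)

# export=False: read the class name back out of the rbln_config stored in an
# already-compiled model directory (see get_rbln_model_class_name above).
reloaded = RBLNAutoModelForCausalLM.from_pretrained("./llama-2-7b-rbln", export=False)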

optimum/rbln/transformers/models/auto/modeling_auto.py
@@ -21,18 +21,31 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+
 from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_CAUSAL_LM_MAPPING,
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_CTC_MAPPING,
     MODEL_FOR_CTC_MAPPING_NAMES,
+    MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
     MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
+    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
     MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_MASKED_LM_MAPPING,
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
+    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
 
@@ -48,48 +61,60 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
 
 
 class RBLNAutoModel(_BaseAutoModelClass):
-    _model_mapping = MODEL_MAPPING_NAMES
+    _model_mapping = MODEL_MAPPING
+    _model_mapping_names = MODEL_MAPPING_NAMES
 
 
 class RBLNAutoModelForCTC(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_CTC_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_CTC_MAPPING
+    _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
 
 
 class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+    _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+    _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
+    _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
+    _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+    _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
+    _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+    _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
+    _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
+    _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
+    _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
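
The reason each class now carries two attributes: `_model_mapping` is the lazy config-class-to-model-class mapping consumed by `_get_model_class` in `infer_hf_model_class`, while `_model_mapping_names` is a plain model_type-to-class-name dict used for the cheap support check in `get_rbln_cls`. Both are real `transformers` objects; a quick probe:

from transformers import LlamaConfig
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)

# Lazy mapping keyed by config class: resolves to the actual model class.
print(MODEL_FOR_CAUSAL_LM_MAPPING[LlamaConfig])        # <class '...LlamaForCausalLM'>

# Plain name mapping keyed by model_type: just strings, cheap to search.
print(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["llama"])      # 'LlamaForCausalLM'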

optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -24,9 +24,9 @@
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
-from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
+from transformers import BartForConditionalGeneration, PretrainedConfig
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
@@ -41,9 +41,6 @@ if TYPE_CHECKING:
 
 
 class RBLNBartModel(RBLNModel):
-    original_model_class = BartModel
-    original_config_class = BartConfig
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -82,7 +79,7 @@ class RBLNBartModel(RBLNModel):
         if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
             rbln_model_input_names = cls.rbln_model_input_names
         elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-            input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
             raise ValueError(
                 "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                 f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(input_names_order)})"

optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -25,9 +25,9 @@ import inspect
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
-from transformers import BertConfig, BertModel, PretrainedConfig
+from transformers import PretrainedConfig
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
@@ -38,9 +38,6 @@
 
 
 class RBLNBertModel(RBLNModel):
-    original_model_class = BertModel
-    original_config_class = BertConfig
-
     @classmethod
     def _get_rbln_config(
         cls,
@@ -75,7 +72,7 @@ class RBLNBertModel(RBLNModel):
         if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
             rbln_model_input_names = cls.rbln_model_input_names
         elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-            input_names_order = inspect.signature(cls.original_model_class.forward).parameters.keys()
+            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
             raise ValueError(
                 "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                 f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(input_names_order)})"

optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -26,19 +26,17 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 from transformers import (
-    AutoConfig,
-    AutoModel,
     CLIPTextConfig,
     CLIPTextModel,
-    CLIPTextModelWithProjection,
     CLIPVisionConfig,
     CLIPVisionModel,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
 
 
 logger = logging.getLogger(__name__)
@@ -58,24 +56,14 @@ class _TextEncoder(torch.nn.Module):
 
 
 class RBLNCLIPTextModel(RBLNModel):
-    original_model_class = CLIPTextModel
-    original_config_class = CLIPTextConfig
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        return rbln_config
+
     @classmethod
     def _get_rbln_config(
         cls,
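
`update_rbln_config_using_pipe` is new in this release and is presumably the hook `RBLNDiffusionMixin` calls on each submodule class when compiling a whole pipeline; for the text encoder it is a no-op. A hypothetical sketch of that dispatch (`build_submodule_configs` and its arguments are invented for illustration, not the mixin's real API):

from typing import Any, Dict

def build_submodule_configs(pipe, submodule_classes: Dict[str, type], rbln_config: Dict[str, Any]):
    # Give every submodule class a chance to adjust the pipeline-level config
    # before compilation; RBLNCLIPTextModel returns it unchanged, while e.g.
    # a UNet class might inject sample sizes or batch multipliers.
    return {
        name: rbln_cls.update_rbln_config_using_pipe(pipe, dict(rbln_config))
        for name, rbln_cls in submodule_classes.items()
    }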
@@ -119,7 +107,7 @@ class RBLNCLIPTextModel(RBLNModel):
 
 
 class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
-    original_model_class = CLIPTextModelWithProjection
+    pass
 
 
 class _VisionEncoder(torch.nn.Module):
@@ -133,20 +121,6 @@ class _VisionEncoder(torch.nn.Module):
 
 
 class RBLNCLIPVisionModel(RBLNModel):
-    original_model_class = CLIPVisionModel
-    original_config_class = CLIPVisionConfig
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        AutoConfig.from_pretrained = cls.original_config_class.from_pretrained
-        AutoModel.from_pretrained = cls.original_model_class.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        return rt
-
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         return _VisionEncoder(model).eval()
@@ -155,7 +129,7 @@ class RBLNCLIPVisionModel(RBLNModel):
     def _get_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
-        model_config: "CLIPTextConfig",
+        model_config: "CLIPVisionConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
         rbln_batch_size = rbln_kwargs.get("batch_size", 1)

optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -22,12 +22,7 @@
 # from Rebellions Inc.
 
 from .decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyDecoderLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    DynamicNTKScalingRotaryEmbedding,
-    LinearScalingRotaryEmbedding,
     RotaryEmbedding,
     apply_rotary_pos_emb,
     rotate_half,
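
The surviving exports center on rotary position embeddings. For reference, the standard Hugging Face-style helpers behind these names compute the following; this is a generic sketch consistent with the usual implementation, not code copied from this package:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # RoPE: q' = q * cos + rotate_half(q) * sin, and likewise for k.
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# Shape check: (batch, heads, seq, head_dim) queries against (seq, head_dim) tables.
q = torch.randn(1, 8, 4, 64)
k = torch.randn(1, 8, 4, 64)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
freqs = torch.outer(torch.arange(4).float(), inv_freq)  # (seq, head_dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                 # (seq, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
assert q_rot.shape == q.shape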