optimum_rbln-0.1.9-py3-none-any.whl → optimum_rbln-0.1.12-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .__version__ import __version__
 from .utils import check_version_compats
 
 
@@ -34,11 +35,10 @@ _import_structure = {
         "RBLNBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
-        "RBLNT5ForConditionalGeneration",
-        "RBLNBartForConditionalGeneration",
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNRobertaForSequenceClassification",
         "RBLNRobertaForMaskedLM",
+        "RBLNViTForImageClassification"
     ],
     "modeling_base": [
         "RBLNBaseModel",
@@ -49,18 +49,36 @@ _import_structure = {
         "RBLNModelForSequenceClassification",
         "RBLNModelForMaskedLM",
     ],
-    "modeling_seq2seq": [
-        "RBLNModelForSeq2SeqLM",
-    ],
     "transformers": [
         "BatchTextIteratorStreamer",
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartForConditionalGeneration",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
+        "RBLNExaoneForCausalLM",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
+        "RBLNQwen2ForCausalLM",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNT5ForConditionalGeneration",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
        "RBLNMidmLMHeadModel",
         "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
@@ -80,7 +98,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
-    "modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
+    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
 }
 
 if TYPE_CHECKING:
@@ -100,12 +118,12 @@ if TYPE_CHECKING:
     )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
-        RBLNBartForConditionalGeneration,
         RBLNBertForQuestionAnswering,
         RBLNResNetForImageClassification,
         RBLNRobertaForMaskedLM,
         RBLNRobertaForSequenceClassification,
         RBLNT5ForConditionalGeneration,
+        RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
     from .modeling_base import (
@@ -117,18 +135,38 @@ if TYPE_CHECKING:
         RBLNModelForQuestionAnswering,
         RBLNModelForSequenceClassification,
     )
-    from .modeling_config import RBLNConfig, RBLNRuntimeConfig
-    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
+    from .modeling_config import RBLNCompileConfig, RBLNConfig
     from .transformers import (
         BatchTextIteratorStreamer,
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartForConditionalGeneration,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
+        RBLNExaoneForCausalLM,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
+        RBLNQwen2ForCausalLM,
+        RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
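
Note: the __init__.py changes above add a family of RBLNAutoModel* classes, move the seq2seq and T5/BART classes under optimum.rbln.transformers, and rename RBLNRuntimeConfig to RBLNCompileConfig. A minimal usage sketch of one of the new auto classes, assuming the usual optimum-style from_pretrained(..., export=True) entry point; the checkpoint id and rbln_* options are illustrative placeholders, not taken from this diff:

from optimum.rbln import RBLNAutoModelForCausalLM

# Compile a causal LM for the RBLN NPU; rbln_* kwargs are compile-time
# options collected by the library (values here are hypothetical).
model = RBLNAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # hypothetical checkpoint
    export=True,                 # compile from the original PyTorch weights
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)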
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = '0.1.9'
+__version__ = '0.1.12'
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
 import torch  # noqa: I001
@@ -34,7 +34,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
@@ -58,15 +58,12 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
 
 
 class RBLNAutoencoderKL(RBLNModel):
-    model_type = "rbln_model"
     config_name = "config.json"
-    auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
-
-        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
+        super().__post_init__(**kwargs)
 
+        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
         if self.rbln_use_encode:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
@@ -81,20 +78,20 @@ class RBLNAutoencoderKL(RBLNModel):
             encoder_model.eval()
             decoder_model.eval()
 
-            enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
-            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
+            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
 
-            return enc_compiled_model, dec_compiled_model
+            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
             decoder_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
 
             return dec_compiled_model
 
-        if rbln_config.meta.get("rbln_use_encode", False):
+        if rbln_config.model_cfg.get("use_encode", False):
             return compile_img2img()
         else:
             return compile_text2img()
@@ -133,23 +130,23 @@ class RBLNAutoencoderKL(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_unet_sample_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
    ) -> RBLNConfig:
-        meta = {}
+        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        meta["rbln_use_encode"] = rbln_use_encode
-        meta["rbln_batch_size"] = rbln_batch_size
+        model_cfg = {}
 
         if rbln_use_encode:
-            meta["rbln_img_width"] = rbln_img_width
-            meta["rbln_img_height"] = rbln_img_height
+            model_cfg["img_width"] = rbln_img_width
+            model_cfg["img_height"] = rbln_img_height
 
             vae_enc_input_info = [
                 ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
@@ -167,20 +164,23 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ]
 
-            enc_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
-            dec_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
+            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
 
-            rbln_config = RBLNConfig.from_rbln_runtime_configs(
-                [enc_rbln_runtime_config, dec_rbln_runtime_config],
-                _rbln_meta=meta,
+            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+            rbln_config = RBLNConfig(
+                rbln_cls=cls.__name__,
+                compile_cfgs=compile_cfgs,
+                rbln_kwargs=rbln_kwargs,
             )
+            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
         if rbln_unet_sample_size is None:
             rbln_unet_sample_size = 64
 
-        meta["rbln_unet_sample_size"] = rbln_unet_sample_size
-        vae_config = RBLNRuntimeConfig(
+        model_cfg["unet_sample_size"] = rbln_unet_sample_size
+        vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
@@ -194,7 +194,12 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ],
         )
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[vae_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
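
Note: the _get_rbln_config rewrite above replaces per-option keyword parameters with a single rbln_kwargs dict whose keys drop the rbln_ prefix. A sketch of the two call shapes, with illustrative values only:

# 0.1.9: options were explicit keyword parameters on the method.
rbln_config = RBLNAutoencoderKL._get_rbln_config(
    preprocessors, model_config,
    rbln_use_encode=True, rbln_img_width=512, rbln_img_height=512,
    rbln_vae_scale_factor=8,
)

# 0.1.12: the same options travel in one dict and are read back with
# rbln_kwargs.get("use_encode"), rbln_kwargs.get("img_width"), etc.
rbln_config = RBLNAutoencoderKL._get_rbln_config(
    preprocessors, model_config,
    rbln_kwargs={"use_encode": True, "img_width": 512, "img_height": 512, "vae_scale_factor": 8},
)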
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
@@ -31,7 +31,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -105,13 +105,10 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 
 
 class RBLNControlNetModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel  # feature extraction
-
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
+        super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
-            item[0] == "encoder_hidden_states" for item in self.rbln_config["compiled_model"][0].input_info
+            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
     @classmethod
@@ -121,9 +118,6 @@ class RBLNControlNetModel(RBLNModel):
            model_name_or_path: Union[str, Path],
            **kwargs,
        ):
-            if "subfolder" in kwargs:
-                del kwargs["subfolder"]
-
            return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
        tasktmp = TasksManager.get_model_from_task
@@ -155,14 +149,14 @@ class RBLNControlNetModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
    ) -> RBLNConfig:
-        meta = {"type": "controlnet"}
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -170,28 +164,29 @@ class RBLNControlNetModel(RBLNModel):
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
 
+        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+
         input_width = rbln_img_width // rbln_vae_scale_factor
         input_height = rbln_img_height // rbln_vae_scale_factor
 
-        rbln_runtime_config = RBLNRuntimeConfig(
-            input_info=[
-                (
-                    "sample",
-                    [
-                        rbln_batch_size,
-                        model_config.in_channels,
-                        input_height,
-                        input_width,
-                    ],
-                    "float32",
-                ),
-                ("timestep", [], "float32"),
-            ],
-            batch_size=rbln_batch_size,
-        )
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    input_height,
+                    input_width,
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+        ]
+
         use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
         if use_encoder_hidden_states:
-            rbln_runtime_config.input_info.append(
+            input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -202,19 +197,34 @@ class RBLNControlNetModel(RBLNModel):
                     "float32",
                 )
             )
-            rbln_runtime_config.input_info.append(
-                ("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32")
-            )
-            rbln_runtime_config.input_info.append(("conditioning_scale", [], "float32"))
+
+        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(("conditioning_scale", [], "float32"))
+
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
-            rbln_runtime_config.input_info.append(
-                ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
-            )
-            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "img_width": rbln_img_width,
+                "img_height": rbln_img_height,
+                "vae_scale_factor": rbln_vae_scale_factor,
+            }
+        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
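
Note: throughout these hunks, input_info entries follow a (name, shape, dtype) tuple convention, with [] standing for a scalar input. A standalone sketch, assuming a 512x512 image and a VAE scale factor of 8 (shapes illustrative):

input_info = [
    ("sample", [1, 4, 64, 64], "float32"),             # latent input: 512 // 8 = 64
    ("timestep", [], "float32"),                       # scalar timestep
    ("controlnet_cond", [1, 3, 512, 512], "float32"),  # full-resolution condition image
    ("conditioning_scale", [], "float32"),             # scalar conditioning scale
]
rbln_compile_config = RBLNCompileConfig(input_info=input_info)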
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -32,7 +32,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -126,12 +126,9 @@ class _UNet_SDXL(torch.nn.Module):
 
 
 class RBLNUNet2DConditionModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel  # feature extraction
-
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
-        self.in_features = self.rbln_config.meta.get("in_features", None)
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
         if self.in_features is not None:
 
             @dataclass
@@ -183,31 +180,31 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_in_features: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
-        rbln_is_controlnet: Optional[bool] = None,
+        rbln_kwargs: Dict[str, Any] = {},
    ) -> RBLNConfig:
-        meta = {"type": "unet"}
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_in_features = rbln_kwargs.get("in_features", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
 
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
-
-        meta["rbln_use_encode"] = rbln_use_encode
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
 
         if rbln_use_encode:
-            # FIXME :: robust img shape getter
+            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+                raise ValueError(
+                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
+                )
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
-            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -232,6 +229,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 "float32",
             ),
         ]
+
         if rbln_is_controlnet:
             if len(model_config.block_out_channels) > 0:
                 input_info.extend(
@@ -304,24 +302,35 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 )
             )
 
-        rbln_runtime_config = RBLNRuntimeConfig(
-            input_info=input_info,
-            batch_size=rbln_batch_size,
-        )
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            # In case of sdxl
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
             if rbln_in_features is None:
                 rbln_in_features = model_config.projection_class_embeddings_input_dim
-            meta["in_features"] = rbln_in_features
-            rbln_runtime_config.input_info.append(
+            rbln_compile_config.input_info.append(
                 ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
             )
-            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "use_encode": rbln_use_encode,
+            }
+        )
+
+        if rbln_in_features is not None:
+            rbln_config.model_cfg["in_features"] = rbln_in_features
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
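
Note: as the __post_init__ hunks show, the restructured RBLNConfig splits runtime state into model_cfg (per-model options) and compile_cfgs (per-compiled-graph I/O metadata). A short sketch of how a loaded model reads both back, using only accessors that appear in this diff:

# Per-model options land in model_cfg ...
use_encode = rbln_config.model_cfg.get("use_encode", False)
in_features = rbln_config.model_cfg.get("in_features", None)

# ... while graph I/O signatures land in compile_cfgs[i].input_info.
input_names = [item[0] for item in rbln_config.compile_cfgs[0].input_info]
has_text_cond = "encoder_hidden_states" in input_names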
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -102,6 +102,10 @@ class RBLNMultiControlNetModel(RBLNModel):
             real_save_path = save_directory + suffix
             model.save_pretrained(real_save_path)
 
+    @classmethod
+    def _get_rbln_config(cls, **rbln_config_kwargs):
+        pass
+
     def forward(
         self,
         sample: torch.FloatTensor,