optimum-rbln 0.1.9-py3-none-any.whl → 0.1.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .__version__ import __version__
 from .utils import check_version_compats
 
 
@@ -54,13 +55,30 @@ _import_structure = {
     ],
     "transformers": [
         "BatchTextIteratorStreamer",
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
         "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
@@ -80,7 +98,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
-    "modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
+    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
 }
 
 if TYPE_CHECKING:
@@ -117,18 +135,35 @@ if TYPE_CHECKING:
         RBLNModelForQuestionAnswering,
         RBLNModelForSequenceClassification,
     )
-    from .modeling_config import RBLNConfig, RBLNRuntimeConfig
+    from .modeling_config import RBLNCompileConfig, RBLNConfig
     from .modeling_seq2seq import RBLNModelForSeq2SeqLM
     from .transformers import (
         BatchTextIteratorStreamer,
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
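
Note on the additions above: 0.1.11 introduces an AutoModel-style entry point (RBLNAutoModel plus task-specific variants) alongside the existing per-architecture classes. A minimal usage sketch, assuming the transformers-style from_pretrained flow used elsewhere in this package; the checkpoint id and option values are illustrative, not taken from this diff:

    from optimum.rbln import RBLNAutoModelForCausalLM

    # Compile a causal LM for the RBLN NPU; the auto class resolves the
    # concrete RBLN* class (e.g. RBLNLlamaForCausalLM) from the model config.
    model = RBLNAutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",     # illustrative checkpoint
        export=True,                    # compile from the PyTorch weights
        rbln_config={"batch_size": 1},  # dict-style config introduced in this release
    )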
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = '0.1.9'
+__version__ = '0.1.11'
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
 import torch  # noqa: I001
@@ -34,7 +34,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
@@ -63,10 +63,9 @@ class RBLNAutoencoderKL(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
-
-        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
+        super().__post_init__(**kwargs)
 
+        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
         if self.rbln_use_encode:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
@@ -81,20 +80,20 @@ class RBLNAutoencoderKL(RBLNModel):
             encoder_model.eval()
             decoder_model.eval()
 
-            enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
-            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
+            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
 
-            return enc_compiled_model, dec_compiled_model
+            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
             decoder_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
 
             return dec_compiled_model
 
-        if rbln_config.meta.get("rbln_use_encode", False):
+        if rbln_config.model_cfg.get("use_encode", False):
             return compile_img2img()
         else:
             return compile_text2img()
@@ -133,23 +132,23 @@ class RBLNAutoencoderKL(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_unet_sample_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        meta["rbln_use_encode"] = rbln_use_encode
-        meta["rbln_batch_size"] = rbln_batch_size
+        model_cfg = {}
 
         if rbln_use_encode:
-            meta["rbln_img_width"] = rbln_img_width
-            meta["rbln_img_height"] = rbln_img_height
+            model_cfg["img_width"] = rbln_img_width
+            model_cfg["img_height"] = rbln_img_height
 
             vae_enc_input_info = [
                 ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
@@ -167,20 +166,23 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ]
 
-            enc_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
-            dec_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
+            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
 
-            rbln_config = RBLNConfig.from_rbln_runtime_configs(
-                [enc_rbln_runtime_config, dec_rbln_runtime_config],
-                _rbln_meta=meta,
+            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+            rbln_config = RBLNConfig(
+                rbln_cls=cls.__name__,
+                compile_cfgs=compile_cfgs,
+                rbln_kwargs=rbln_kwargs,
             )
+            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
         if rbln_unet_sample_size is None:
             rbln_unet_sample_size = 64
 
-        meta["rbln_unet_sample_size"] = rbln_unet_sample_size
-        vae_config = RBLNRuntimeConfig(
+        model_cfg["unet_sample_size"] = rbln_unet_sample_size
+        vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
@@ -194,7 +196,12 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ],
         )
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[vae_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
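
The autoencoder hunks above show the configuration rework that repeats throughout this release: per-runtime RBLNRuntimeConfig objects and a loose _rbln_meta dict are replaced by RBLNCompileConfig entries gathered into a single RBLNConfig, with model-level settings kept in model_cfg. A condensed sketch of the new construction, restating only calls visible in this diff (the model_cfg values below are placeholders):

    enc_cfg = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
    dec_cfg = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)

    rbln_config = RBLNConfig(
        rbln_cls=cls.__name__,            # owning RBLN model class
        compile_cfgs=[enc_cfg, dec_cfg],  # one entry per compiled graph
        rbln_kwargs=rbln_kwargs,          # user-supplied options
    )
    # Model-level flags move from the old meta dict into model_cfg.
    rbln_config.model_cfg.update({"use_encode": True, "img_width": 512, "img_height": 512})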
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
@@ -31,7 +31,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -109,21 +109,21 @@ class RBLNControlNetModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
+        super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
-            item[0] == "encoder_hidden_states" for item in self.rbln_config["compiled_model"][0].input_info
+            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
+        if "subfolder" in kwargs:
+            del kwargs["subfolder"]
+
         def get_model_from_task(
             task: str,
             model_name_or_path: Union[str, Path],
             **kwargs,
         ):
-            if "subfolder" in kwargs:
-                del kwargs["subfolder"]
-
             return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
         tasktmp = TasksManager.get_model_from_task
@@ -155,14 +155,14 @@ class RBLNControlNetModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {"type": "controlnet"}
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -170,28 +170,29 @@ class RBLNControlNetModel(RBLNModel):
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
 
+        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+
         input_width = rbln_img_width // rbln_vae_scale_factor
         input_height = rbln_img_height // rbln_vae_scale_factor
 
-        rbln_runtime_config = RBLNRuntimeConfig(
-            input_info=[
-                (
-                    "sample",
-                    [
-                        rbln_batch_size,
-                        model_config.in_channels,
-                        input_height,
-                        input_width,
-                    ],
-                    "float32",
-                ),
-                ("timestep", [], "float32"),
-            ],
-            batch_size=rbln_batch_size,
-        )
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    input_height,
+                    input_width,
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+        ]
+
         use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
         if use_encoder_hidden_states:
-            rbln_runtime_config.input_info.append(
+            input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -202,19 +203,34 @@ class RBLNControlNetModel(RBLNModel):
                     "float32",
                 )
             )
-            rbln_runtime_config.input_info.append(
-                ("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32")
-            )
-            rbln_runtime_config.input_info.append(("conditioning_scale", [], "float32"))
+
+        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(("conditioning_scale", [], "float32"))
+
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
-            rbln_runtime_config.input_info.append(
-                ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
-            )
-            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "img_width": rbln_img_width,
+                "img_height": rbln_img_height,
+                "vae_scale_factor": rbln_vae_scale_factor,
+            }
+        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
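
A second recurring change, visible in the controlnet hunks above: the classmethods that build each model's RBLNConfig no longer take one rbln_*-prefixed keyword parameter per option but a single rbln_kwargs dict whose keys drop the rbln_ prefix. Roughly, under that reading (values and the call shape are illustrative):

    # 0.1.9 — one keyword argument per option:
    #     _get_rbln_config(preprocessors, model_config,
    #                      rbln_batch_size=2, rbln_img_width=768, rbln_img_height=768, ...)
    # 0.1.11 — one dict, read back with .get():
    rbln_kwargs = {"batch_size": 2, "img_width": 768, "img_height": 768, "vae_scale_factor": 8}
    rbln_batch_size = rbln_kwargs.get("batch_size", None)  # keys lose the "rbln_" prefix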
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -32,7 +32,7 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 if TYPE_CHECKING:
@@ -130,8 +130,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
-        self.in_features = self.rbln_config.meta.get("in_features", None)
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
         if self.in_features is not None:
 
             @dataclass
@@ -183,31 +183,31 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_in_features: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
-        rbln_is_controlnet: Optional[bool] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {"type": "unet"}
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_in_features = rbln_kwargs.get("in_features", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
 
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
-
-        meta["rbln_use_encode"] = rbln_use_encode
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
 
         if rbln_use_encode:
-            # FIXME :: robust img shape getter
+            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+                raise ValueError(
+                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
+                )
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
-            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -232,6 +232,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 "float32",
             ),
         ]
+
         if rbln_is_controlnet:
             if len(model_config.block_out_channels) > 0:
                 input_info.extend(
@@ -304,24 +305,35 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 )
             )
 
-        rbln_runtime_config = RBLNRuntimeConfig(
-            input_info=input_info,
-            batch_size=rbln_batch_size,
-        )
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            # In case of sdxl
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
             if rbln_in_features is None:
                 rbln_in_features = model_config.projection_class_embeddings_input_dim
-            meta["in_features"] = rbln_in_features
-            rbln_runtime_config.input_info.append(
+            rbln_compile_config.input_info.append(
                 ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
             )
-            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "use_encode": rbln_use_encode,
+            }
+        )
+
+        if rbln_in_features is not None:
+            rbln_config.model_cfg["in_features"] = rbln_in_features
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
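
Throughout these files, each input_info entry is a (name, shape, dtype) triple describing one static input of the compiled graph; an empty shape list denotes a scalar, as with "timestep" above. A small annotated sketch with placeholder dimensions:

    input_info = [
        # (input name, static shape, dtype string)
        ("sample", [1, 4, 64, 64], "float32"),  # batch, channels, height, width (placeholders)
        ("timestep", [], "float32"),            # [] marks a scalar input
    ]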
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py CHANGED
@@ -37,6 +37,7 @@ from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel
+from ....utils.runtime_utils import ContextRblnConfig
 from ...models import RBLNAutoencoderKL, RBLNControlNetModel, RBLNUNet2DConditionModel
 from ...pipelines.controlnet.multicontrolnet import RBLNMultiControlNetModel
 
@@ -69,8 +70,13 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         text_encoder = kwargs.pop("text_encoder", None)
         controlnet = kwargs.pop("controlnet", None)
         model_save_dir = kwargs.pop("model_save_dir", None)
+        rbln_config = kwargs.pop("rbln_config", None)
+        rbln_kwargs, _ = RBLNBaseModel.resolve_rbln_config(rbln_config, kwargs)
 
-        rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+        device = rbln_kwargs.get("device", None)
+        device_map = rbln_kwargs.get("device_map", None)
+        create_runtimes = rbln_kwargs.get("create_runtimes", None)
+        optimize_host_memory = rbln_kwargs.get("optimize_host_memory", None)
 
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
@@ -98,13 +104,19 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
             }
         )
 
-        model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
+        with ContextRblnConfig(
+            device=device,
+            device_map=device_map,
+            create_runtimes=create_runtimes,
+            optimze_host_mem=optimize_host_memory,
+        ):
+            model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
         do_classifier_free_guidance = (
-            rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
+            rbln_kwargs.pop("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
         # compile model, create runtime
@@ -117,8 +129,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                 rbln_unet_sample_size=model.unet.config.sample_size,
                 rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
-                **rbln_config_kwargs,
-                **rbln_constructor_kwargs,
+                rbln_config={**rbln_kwargs},
             )
 
         if not isinstance(text_encoder, RBLNCLIPTextModel):
@@ -127,11 +138,10 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                 subfolder="text_encoder",
                 export=True,
                 model_save_dir=model_save_dir,
-                **rbln_config_kwargs,
-                **rbln_constructor_kwargs,
+                rbln_config={**rbln_kwargs},
             )
 
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        batch_size = rbln_kwargs.pop("batch_size", 1)
         unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
 
         if not isinstance(unet, RBLNUNet2DConditionModel):
@@ -145,8 +155,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                 rbln_use_encode=False,
                 rbln_vae_scale_factor=model.vae_scale_factor,
                 rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-                **rbln_config_kwargs,
-                **rbln_constructor_kwargs,
+                rbln_config={**rbln_kwargs},
             )
         if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
 
@@ -162,8 +171,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                         model_save_dir=model_save_dir,
                         rbln_batch_size=unet_batch_size,
                         rbln_vae_scale_factor=model.vae_scale_factor,
-                        **rbln_config_kwargs,
-                        **rbln_constructor_kwargs,
+                        rbln_config={**rbln_kwargs},
                     )
                 )
                 controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
@@ -176,8 +184,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
                 model_save_dir=model_save_dir,
                 rbln_batch_size=unet_batch_size,
                 rbln_vae_scale_factor=model.vae_scale_factor,
-                **rbln_config_kwargs,
-                **rbln_constructor_kwargs,
+                rbln_config={**rbln_kwargs},
             )
             controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
@@ -209,7 +216,7 @@ class RBLNStableDiffusionControlNetPipeline(StableDiffusionControlNetPipeline):
         model.save_config(model_save_dir)
 
         # use for CI to access each compiled model
-        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+        if optimize_host_memory is False:
             model.compiled_models = [vae.compiled_models[0], text_encoder.compiled_models[0], unet.compiled_models[0]]
             if isinstance(controlnet, RBLNMultiControlNetModel):
                 for c_model in controlnet.nets:
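
Taken together, the pipeline hunks replace the old rbln_*-prefixed constructor kwargs with a single rbln_config dict, whose runtime keys (device, device_map, create_runtimes, optimize_host_memory) are propagated to the submodels through the new ContextRblnConfig context manager during loading. A hedged end-to-end sketch; the checkpoint id and option values are illustrative, and controlnet loading is omitted for brevity:

    from optimum.rbln import RBLNStableDiffusionControlNetPipeline

    pipe = RBLNStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
        export=True,                       # compile vae / text_encoder / unet / controlnet
        rbln_config={
            "batch_size": 1,
            "guidance_scale": 7.5,         # >1.0 doubles the UNet batch for CFG
            "device": 0,                   # consumed via ContextRblnConfig
        },
    )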