optimum-rbln 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +5 -1
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -61
  4. optimum/rbln/diffusers/models/controlnet.py +36 -56
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -153
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +7 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -190
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -191
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -192
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -110
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -115
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -122
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -125
  16. optimum/rbln/modeling_base.py +12 -5
  17. optimum/rbln/modeling_diffusers.py +400 -0
  18. optimum/rbln/transformers/__init__.py +2 -0
  19. optimum/rbln/transformers/cache_utils.py +5 -9
  20. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  21. optimum/rbln/transformers/models/__init__.py +80 -31
  22. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -22
  23. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  25. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +74 -16
  26. optimum/rbln/transformers/models/exaone/exaone_architecture.py +18 -9
  27. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -29
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  29. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  30. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  31. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  32. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +27 -8
  33. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +4 -29
  35. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  36. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  37. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  38. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  39. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  40. optimum/rbln/transformers/models/t5/modeling_t5.py +57 -4
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  42. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  43. optimum/rbln/utils/context.py +58 -0
  44. optimum/rbln/utils/decorator_utils.py +55 -0
  45. optimum/rbln/utils/import_utils.py +7 -0
  46. optimum/rbln/utils/runtime_utils.py +4 -4
  47. optimum/rbln/utils/timer_utils.py +2 -2
  48. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +8 -7
  49. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/RECORD +52 -48
  50. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +0 -0
  51. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  52. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -38,7 +38,7 @@ _import_structure = {
38
38
  "RBLNXLMRobertaForSequenceClassification",
39
39
  "RBLNRobertaForSequenceClassification",
40
40
  "RBLNRobertaForMaskedLM",
41
- "RBLNViTForImageClassification"
41
+ "RBLNViTForImageClassification",
42
42
  ],
43
43
  "modeling_base": [
44
44
  "RBLNBaseModel",
@@ -76,6 +76,7 @@ _import_structure = {
76
76
  "RBLNQwen2ForCausalLM",
77
77
  "RBLNWav2Vec2ForCTC",
78
78
  "RBLNLlamaForCausalLM",
79
+ "RBLNT5EncoderModel",
79
80
  "RBLNT5ForConditionalGeneration",
80
81
  "RBLNPhiForCausalLM",
81
82
  "RBLNLlavaNextForConditionalGeneration",
@@ -99,6 +100,7 @@ _import_structure = {
99
100
  "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
100
101
  ],
101
102
  "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
103
+ "modeling_diffusers": ["RBLNDiffusionMixin"],
102
104
  }
103
105
 
104
106
  if TYPE_CHECKING:
@@ -136,6 +138,7 @@ if TYPE_CHECKING:
136
138
  RBLNModelForSequenceClassification,
137
139
  )
138
140
  from .modeling_config import RBLNCompileConfig, RBLNConfig
141
+ from .modeling_diffusers import RBLNDiffusionMixin
139
142
  from .transformers import (
140
143
  BatchTextIteratorStreamer,
141
144
  RBLNAutoModel,
@@ -166,6 +169,7 @@ if TYPE_CHECKING:
166
169
  RBLNMistralForCausalLM,
167
170
  RBLNPhiForCausalLM,
168
171
  RBLNQwen2ForCausalLM,
172
+ RBLNT5EncoderModel,
169
173
  RBLNT5ForConditionalGeneration,
170
174
  RBLNWav2Vec2ForCTC,
171
175
  RBLNWhisperForConditionalGeneration,
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = '0.1.12'
1
+ __version__ = '0.1.13'
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -22,7 +22,6 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import logging
25
- from pathlib import Path
26
25
  from typing import TYPE_CHECKING, Any, Dict, List, Union
27
26
 
28
27
  import rebel
@@ -30,11 +29,11 @@ import torch # noqa: I001
30
29
  from diffusers import AutoencoderKL
31
30
  from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
32
31
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
33
- from optimum.exporters import TasksManager
34
- from transformers import AutoConfig, AutoModel, PretrainedConfig
32
+ from transformers import PretrainedConfig
35
33
 
36
34
  from ...modeling_base import RBLNModel
37
35
  from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
36
+ from ...utils.context import override_auto_classes
38
37
  from ...utils.runtime_utils import RBLNPytorchRuntime
39
38
 
40
39
 
@@ -63,8 +62,7 @@ class RBLNAutoencoderKL(RBLNModel):
63
62
  def __post_init__(self, **kwargs):
64
63
  super().__post_init__(**kwargs)
65
64
 
66
- self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
67
- if self.rbln_use_encode:
65
+ if self.rbln_config.model_cfg.get("img2img_pipeline"):
68
66
  self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
69
67
  self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
70
68
  else:
@@ -91,38 +89,15 @@ class RBLNAutoencoderKL(RBLNModel):
91
89
 
92
90
  return dec_compiled_model
93
91
 
94
- if rbln_config.model_cfg.get("use_encode", False):
92
+ if rbln_config.model_cfg.get("img2img_pipeline"):
95
93
  return compile_img2img()
96
94
  else:
97
95
  return compile_text2img()
98
96
 
99
97
  @classmethod
100
98
  def from_pretrained(cls, *args, **kwargs):
101
- def get_model_from_task(
102
- task: str,
103
- model_name_or_path: Union[str, Path],
104
- **kwargs,
105
- ):
106
- return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
107
-
108
- tasktmp = TasksManager.get_model_from_task
109
- configtmp = AutoConfig.from_pretrained
110
- modeltmp = AutoModel.from_pretrained
111
- TasksManager.get_model_from_task = get_model_from_task
112
-
113
- if kwargs.get("export", None):
114
- # This is an ad-hoc to workaround save null values of the config.
115
- # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
116
- # and diffusers model do not support loading by AutoConfig.
117
- AutoConfig.from_pretrained = lambda *args, **kwargs: None
118
- else:
119
- AutoConfig.from_pretrained = AutoencoderKL.load_config
120
-
121
- AutoModel.from_pretrained = AutoencoderKL.from_pretrained
122
- rt = super().from_pretrained(*args, **kwargs)
123
- AutoConfig.from_pretrained = configtmp
124
- AutoModel.from_pretrained = modeltmp
125
- TasksManager.get_model_from_task = tasktmp
99
+ with override_auto_classes(config_func=AutoencoderKL.load_config, model_func=AutoencoderKL.from_pretrained):
100
+ rt = super().from_pretrained(*args, **kwargs)
126
101
  return rt
127
102
 
128
103
  @classmethod
@@ -132,34 +107,39 @@ class RBLNAutoencoderKL(RBLNModel):
132
107
  model_config: "PretrainedConfig",
133
108
  rbln_kwargs: Dict[str, Any] = {},
134
109
  ) -> RBLNConfig:
135
- rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
136
- rbln_img_width = rbln_kwargs.get("img_width", None)
137
- rbln_img_height = rbln_kwargs.get("img_height", None)
138
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
139
- rbln_use_encode = rbln_kwargs.get("use_encode", None)
140
- rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
110
+ rbln_batch_size = rbln_kwargs.get("batch_size")
111
+ sample_size = rbln_kwargs.get("sample_size")
141
112
 
142
113
  if rbln_batch_size is None:
143
114
  rbln_batch_size = 1
144
115
 
145
- model_cfg = {}
116
+ if sample_size is None:
117
+ sample_size = model_config.sample_size
118
+
119
+ if isinstance(sample_size, int):
120
+ sample_size = (sample_size, sample_size)
121
+
122
+ if hasattr(model_config, "block_out_channels"):
123
+ vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
124
+ else:
125
+ # vae image processor default value 8 (int)
126
+ vae_scale_factor = 8
146
127
 
147
- if rbln_use_encode:
148
- model_cfg["img_width"] = rbln_img_width
149
- model_cfg["img_height"] = rbln_img_height
128
+ dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
129
+ enc_shape = (sample_size[0], sample_size[1])
150
130
 
131
+ if rbln_kwargs["img2img_pipeline"]:
151
132
  vae_enc_input_info = [
152
- ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
133
+ (
134
+ "x",
135
+ [rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
136
+ "float32",
137
+ )
153
138
  ]
154
139
  vae_dec_input_info = [
155
140
  (
156
141
  "z",
157
- [
158
- rbln_batch_size,
159
- model_config.latent_channels,
160
- rbln_img_height // rbln_vae_scale_factor,
161
- rbln_img_width // rbln_vae_scale_factor,
162
- ],
142
+ [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
163
143
  "float32",
164
144
  )
165
145
  ]
@@ -173,33 +153,22 @@ class RBLNAutoencoderKL(RBLNModel):
173
153
  compile_cfgs=compile_cfgs,
174
154
  rbln_kwargs=rbln_kwargs,
175
155
  )
176
- rbln_config.model_cfg.update(model_cfg)
177
156
  return rbln_config
178
157
 
179
- if rbln_unet_sample_size is None:
180
- rbln_unet_sample_size = 64
181
-
182
- model_cfg["unet_sample_size"] = rbln_unet_sample_size
183
158
  vae_config = RBLNCompileConfig(
184
159
  input_info=[
185
160
  (
186
161
  "z",
187
- [
188
- rbln_batch_size,
189
- model_config.latent_channels,
190
- rbln_unet_sample_size,
191
- rbln_unet_sample_size,
192
- ],
162
+ [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
193
163
  "float32",
194
164
  )
195
- ],
165
+ ]
196
166
  )
197
167
  rbln_config = RBLNConfig(
198
168
  rbln_cls=cls.__name__,
199
169
  compile_cfgs=[vae_config],
200
170
  rbln_kwargs=rbln_kwargs,
201
171
  )
202
- rbln_config.model_cfg.update(model_cfg)
203
172
  return rbln_config
204
173
 
205
174
  @classmethod
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -22,16 +22,15 @@
22
22
  # from Rebellions Inc.
23
23
 
24
24
  import logging
25
- from pathlib import Path
26
25
  from typing import TYPE_CHECKING, Any, Dict, Optional, Union
27
26
 
28
27
  import torch
29
28
  from diffusers import ControlNetModel
30
- from optimum.exporters import TasksManager
31
- from transformers import AutoConfig, AutoModel, PretrainedConfig
29
+ from transformers import PretrainedConfig
32
30
 
33
31
  from ...modeling_base import RBLNModel
34
32
  from ...modeling_config import RBLNCompileConfig, RBLNConfig
33
+ from ...utils.context import override_auto_classes
35
34
 
36
35
 
37
36
  if TYPE_CHECKING:
@@ -113,23 +112,11 @@ class RBLNControlNetModel(RBLNModel):
113
112
 
114
113
  @classmethod
115
114
  def from_pretrained(cls, *args, **kwargs):
116
- def get_model_from_task(
117
- task: str,
118
- model_name_or_path: Union[str, Path],
119
- **kwargs,
115
+ with override_auto_classes(
116
+ config_func=ControlNetModel.load_config,
117
+ model_func=ControlNetModel.from_pretrained,
120
118
  ):
121
- return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
122
-
123
- tasktmp = TasksManager.get_model_from_task
124
- configtmp = AutoConfig.from_pretrained
125
- modeltmp = AutoModel.from_pretrained
126
- TasksManager.get_model_from_task = get_model_from_task
127
- AutoConfig.from_pretrained = ControlNetModel.load_config
128
- AutoModel.from_pretrained = ControlNetModel.from_pretrained
129
- rt = super().from_pretrained(*args, **kwargs)
130
- AutoConfig.from_pretrained = configtmp
131
- AutoModel.from_pretrained = modeltmp
132
- TasksManager.get_model_from_task = tasktmp
119
+ rt = super().from_pretrained(*args, **kwargs)
133
120
  return rt
134
121
 
135
122
  @classmethod
@@ -151,33 +138,35 @@ class RBLNControlNetModel(RBLNModel):
151
138
  model_config: "PretrainedConfig",
152
139
  rbln_kwargs: Dict[str, Any] = {},
153
140
  ) -> RBLNConfig:
154
- rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
155
- rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
156
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
157
- rbln_img_width = rbln_kwargs.get("img_width", None)
158
- rbln_img_height = rbln_kwargs.get("img_height", None)
159
- rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
141
+ batch_size = rbln_kwargs.get("batch_size")
142
+ max_seq_len = rbln_kwargs.get("max_seq_len")
143
+ unet_sample_size = rbln_kwargs.get("unet_sample_size")
144
+ vae_sample_size = rbln_kwargs.get("vae_sample_size")
160
145
 
161
- if rbln_batch_size is None:
162
- rbln_batch_size = 1
146
+ if batch_size is None:
147
+ batch_size = 1
163
148
 
164
- if rbln_max_seq_len is None:
165
- rbln_max_seq_len = 77
149
+ if unet_sample_size is None:
150
+ raise ValueError(
151
+ "`rbln_unet_sample_size` (latent height, widht) must be specified (ex. unet's sample_size)"
152
+ )
166
153
 
167
- if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
168
- raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
154
+ if vae_sample_size is None:
155
+ raise ValueError(
156
+ "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
157
+ )
169
158
 
170
- input_width = rbln_img_width // rbln_vae_scale_factor
171
- input_height = rbln_img_height // rbln_vae_scale_factor
159
+ if max_seq_len is None:
160
+ raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
172
161
 
173
162
  input_info = [
174
163
  (
175
164
  "sample",
176
165
  [
177
- rbln_batch_size,
166
+ batch_size,
178
167
  model_config.in_channels,
179
- input_height,
180
- input_width,
168
+ unet_sample_size[0],
169
+ unet_sample_size[1],
181
170
  ],
182
171
  "float32",
183
172
  ),
@@ -189,23 +178,24 @@ class RBLNControlNetModel(RBLNModel):
189
178
  input_info.append(
190
179
  (
191
180
  "encoder_hidden_states",
192
- [
193
- rbln_batch_size,
194
- rbln_max_seq_len,
195
- model_config.cross_attention_dim,
196
- ],
181
+ [batch_size, max_seq_len, model_config.cross_attention_dim],
197
182
  "float32",
198
183
  )
199
184
  )
200
185
 
201
- input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
186
+ input_info.append(
187
+ (
188
+ "controlnet_cond",
189
+ [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
190
+ "float32",
191
+ )
192
+ )
202
193
  input_info.append(("conditioning_scale", [], "float32"))
203
194
 
204
195
  if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
205
- if rbln_text_model_hidden_size is None:
206
- rbln_text_model_hidden_size = 768
207
- input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
208
- input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
196
+ rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
197
+ input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
198
+ input_info.append(("time_ids", [batch_size, 6], "float32"))
209
199
 
210
200
  rbln_compile_config = RBLNCompileConfig(input_info=input_info)
211
201
 
@@ -215,16 +205,6 @@ class RBLNControlNetModel(RBLNModel):
215
205
  rbln_kwargs=rbln_kwargs,
216
206
  )
217
207
 
218
- rbln_config.model_cfg.update(
219
- {
220
- "max_seq_len": rbln_max_seq_len,
221
- "batch_size": rbln_batch_size,
222
- "img_width": rbln_img_width,
223
- "img_height": rbln_img_height,
224
- "vae_scale_factor": rbln_vae_scale_factor,
225
- }
226
- )
227
-
228
208
  return rbln_config
229
209
 
230
210
  def forward(
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -23,16 +23,15 @@
23
23
 
24
24
  import logging
25
25
  from dataclasses import dataclass
26
- from pathlib import Path
27
26
  from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
28
27
 
29
28
  import torch
30
29
  from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
31
- from optimum.exporters import TasksManager
32
- from transformers import AutoConfig, AutoModel, PretrainedConfig
30
+ from transformers import PretrainedConfig
33
31
 
34
32
  from ...modeling_base import RBLNModel
35
33
  from ...modeling_config import RBLNCompileConfig, RBLNConfig
34
+ from ...utils.context import override_auto_classes
36
35
 
37
36
 
38
37
  if TYPE_CHECKING:
@@ -143,29 +142,11 @@ class RBLNUNet2DConditionModel(RBLNModel):
143
142
 
144
143
  @classmethod
145
144
  def from_pretrained(cls, *args, **kwargs):
146
- def get_model_from_task(
147
- task: str,
148
- model_name_or_path: Union[str, Path],
149
- **kwargs,
145
+ with override_auto_classes(
146
+ config_func=UNet2DConditionModel.load_config,
147
+ model_func=UNet2DConditionModel.from_pretrained,
150
148
  ):
151
- return UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
152
-
153
- tasktmp = TasksManager.get_model_from_task
154
- configtmp = AutoConfig.from_pretrained
155
- modeltmp = AutoModel.from_pretrained
156
- TasksManager.get_model_from_task = get_model_from_task
157
- if kwargs.get("export", None):
158
- # This is an ad-hoc to workaround save null values of the config.
159
- # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
160
- # and diffusers model do not support loading by AutoConfig.
161
- AutoConfig.from_pretrained = lambda *args, **kwargs: None
162
- else:
163
- AutoConfig.from_pretrained = UNet2DConditionModel.load_config
164
- AutoModel.from_pretrained = UNet2DConditionModel.from_pretrained
165
- rt = super().from_pretrained(*args, **kwargs)
166
- AutoConfig.from_pretrained = configtmp
167
- AutoModel.from_pretrained = modeltmp
168
- TasksManager.get_model_from_task = tasktmp
149
+ rt = super().from_pretrained(*args, **kwargs)
169
150
  return rt
170
151
 
171
152
  @classmethod
@@ -182,137 +163,68 @@ class RBLNUNet2DConditionModel(RBLNModel):
182
163
  model_config: "PretrainedConfig",
183
164
  rbln_kwargs: Dict[str, Any] = {},
184
165
  ) -> RBLNConfig:
185
- rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
186
- rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
187
- rbln_batch_size = rbln_kwargs.get("batch_size", None)
188
- rbln_in_features = rbln_kwargs.get("in_features", None)
189
- rbln_use_encode = rbln_kwargs.get("use_encode", None)
190
- rbln_img_width = rbln_kwargs.get("img_width", None)
191
- rbln_img_height = rbln_kwargs.get("img_height", None)
192
- rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
193
- rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
194
-
195
- if rbln_max_seq_len is None:
196
- rbln_max_seq_len = 77
197
- if rbln_batch_size is None:
198
- rbln_batch_size = 1
199
-
200
- if rbln_use_encode:
201
- if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
202
- raise ValueError(
203
- "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
204
- )
205
- input_width = rbln_img_width // rbln_vae_scale_factor
206
- input_height = rbln_img_height // rbln_vae_scale_factor
207
- else:
208
- input_width, input_height = model_config.sample_size, model_config.sample_size
166
+ batch_size = rbln_kwargs.get("batch_size")
167
+ max_seq_len = rbln_kwargs.get("max_seq_len")
168
+ sample_size = rbln_kwargs.get("sample_size")
169
+ is_controlnet = rbln_kwargs.get("is_controlnet")
170
+ rbln_in_features = None
171
+
172
+ if batch_size is None:
173
+ batch_size = 1
174
+
175
+ if sample_size is None:
176
+ sample_size = model_config.sample_size
177
+
178
+ if isinstance(sample_size, int):
179
+ sample_size = (sample_size, sample_size)
180
+
181
+ if max_seq_len is None:
182
+ raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings )must be specified")
209
183
 
210
184
  input_info = [
211
- (
212
- "sample",
213
- [
214
- rbln_batch_size,
215
- model_config.in_channels,
216
- input_height,
217
- input_width,
218
- ],
219
- "float32",
220
- ),
185
+ ("sample", [batch_size, model_config.in_channels, sample_size[0], sample_size[1]], "float32"),
221
186
  ("timestep", [], "float32"),
222
- (
223
- "encoder_hidden_states",
224
- [
225
- rbln_batch_size,
226
- rbln_max_seq_len,
227
- model_config.cross_attention_dim,
228
- ],
229
- "float32",
230
- ),
187
+ ("encoder_hidden_states", [batch_size, max_seq_len, model_config.cross_attention_dim], "float32"),
231
188
  ]
232
189
 
233
- if rbln_is_controlnet:
234
- if len(model_config.block_out_channels) > 0:
235
- input_info.extend(
236
- [
237
- (
238
- f"down_block_additional_residuals_{i}",
239
- [rbln_batch_size, model_config.block_out_channels[0], input_height, input_width],
240
- "float32",
241
- )
242
- for i in range(3)
243
- ]
244
- )
245
- if len(model_config.block_out_channels) > 1:
246
- input_info.append(
247
- (
248
- "down_block_additional_residuals_3",
249
- [rbln_batch_size, model_config.block_out_channels[0], input_height // 2, input_width // 2],
250
- "float32",
251
- )
252
- )
253
- input_info.extend(
254
- [
255
- (
256
- f"down_block_additional_residuals_{i}",
257
- [rbln_batch_size, model_config.block_out_channels[1], input_height // 2, input_width // 2],
258
- "float32",
259
- )
260
- for i in range(4, 6)
261
- ]
262
- )
263
- if len(model_config.block_out_channels) > 2:
264
- input_info.append(
265
- (
266
- f"down_block_additional_residuals_{6}",
267
- [rbln_batch_size, model_config.block_out_channels[1], input_height // 4, input_width // 4],
268
- "float32",
269
- )
270
- )
271
- input_info.extend(
272
- [
273
- (
274
- f"down_block_additional_residuals_{i}",
275
- [rbln_batch_size, model_config.block_out_channels[2], input_height // 4, input_width // 4],
276
- "float32",
277
- )
278
- for i in range(7, 9)
279
- ]
280
- )
281
- if len(model_config.block_out_channels) > 3:
282
- input_info.extend(
283
- [
284
- (
285
- f"down_block_additional_residuals_{i}",
286
- [rbln_batch_size, model_config.block_out_channels[3], input_height // 8, input_width // 8],
287
- "float32",
288
- )
289
- for i in range(9, 12)
290
- ]
291
- )
292
- input_info.append(
293
- (
294
- "mid_block_additional_residual",
295
- [
296
- rbln_batch_size,
297
- model_config.block_out_channels[-1],
298
- input_height // 2 ** (len(model_config.block_out_channels) - 1),
299
- input_width // 2 ** (len(model_config.block_out_channels) - 1),
300
- ],
301
- "float32",
302
- )
303
- )
190
+ if is_controlnet:
191
+ # down block addtional residuals
192
+ first_shape = [batch_size, model_config.block_out_channels[0], sample_size[0], sample_size[1]]
193
+ height, width = sample_size[0], sample_size[1]
194
+ input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
195
+ name_idx = 1
196
+ for idx, _ in enumerate(model_config.down_block_types):
197
+ shape = [batch_size, model_config.block_out_channels[idx], height, width]
198
+ for _ in range(model_config.layers_per_block):
199
+ input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
200
+ name_idx += 1
201
+ if idx != len(model_config.down_block_types) - 1:
202
+ height = height // 2
203
+ width = width // 2
204
+ shape = [batch_size, model_config.block_out_channels[idx], height, width]
205
+ input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
206
+ name_idx += 1
207
+
208
+ # mid block addtional residual
209
+ num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
210
+ out_channels = model_config.block_out_channels[-1]
211
+ shape = [
212
+ batch_size,
213
+ out_channels,
214
+ sample_size[0] // 2**num_cross_attn_blocks,
215
+ sample_size[1] // 2**num_cross_attn_blocks,
216
+ ]
217
+ input_info.append(("mid_block_additional_residual", shape, "float32"))
304
218
 
305
219
  rbln_compile_config = RBLNCompileConfig(input_info=input_info)
306
220
 
307
221
  if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
308
- if rbln_text_model_hidden_size is None:
309
- rbln_text_model_hidden_size = 768
310
- if rbln_in_features is None:
311
- rbln_in_features = model_config.projection_class_embeddings_input_dim
222
+ rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
223
+ rbln_in_features = model_config.projection_class_embeddings_input_dim
312
224
  rbln_compile_config.input_info.append(
313
- ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
225
+ ("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32")
314
226
  )
315
- rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
227
+ rbln_compile_config.input_info.append(("time_ids", [batch_size, 6], "float32"))
316
228
 
317
229
  rbln_config = RBLNConfig(
318
230
  rbln_cls=cls.__name__,
@@ -320,14 +232,6 @@ class RBLNUNet2DConditionModel(RBLNModel):
320
232
  rbln_kwargs=rbln_kwargs,
321
233
  )
322
234
 
323
- rbln_config.model_cfg.update(
324
- {
325
- "max_seq_len": rbln_max_seq_len,
326
- "batch_size": rbln_batch_size,
327
- "use_encode": rbln_use_encode,
328
- }
329
- )
330
-
331
235
  if rbln_in_features is not None:
332
236
  rbln_config.model_cfg["in_features"] = rbln_in_features
333
237