optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -35,11 +35,10 @@ _import_structure = {
         "RBLNBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
-        "RBLNT5ForConditionalGeneration",
-        "RBLNBartForConditionalGeneration",
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNRobertaForSequenceClassification",
         "RBLNRobertaForMaskedLM",
+        "RBLNViTForImageClassification",
     ],
     "modeling_base": [
         "RBLNBaseModel",
@@ -50,9 +49,6 @@ _import_structure = {
         "RBLNModelForSequenceClassification",
         "RBLNModelForMaskedLM",
     ],
-    "modeling_seq2seq": [
-        "RBLNModelForSeq2SeqLM",
-    ],
     "transformers": [
         "BatchTextIteratorStreamer",
         "RBLNAutoModel",
@@ -67,16 +63,21 @@ _import_structure = {
         "RBLNAutoModelForSequenceClassification",
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
+        "RBLNBartForConditionalGeneration",
         "RBLNBartModel",
         "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
+        "RBLNExaoneForCausalLM",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
+        "RBLNQwen2ForCausalLM",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNT5EncoderModel",
+        "RBLNT5ForConditionalGeneration",
         "RBLNPhiForCausalLM",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
@@ -99,6 +100,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
     "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
+    "modeling_diffusers": ["RBLNDiffusionMixin"],
 }
 
 if TYPE_CHECKING:
@@ -118,12 +120,12 @@ if TYPE_CHECKING:
     )
     from .modeling_alias import (
         RBLNASTForAudioClassification,
-        RBLNBartForConditionalGeneration,
         RBLNBertForQuestionAnswering,
         RBLNResNetForImageClassification,
         RBLNRobertaForMaskedLM,
         RBLNRobertaForSequenceClassification,
         RBLNT5ForConditionalGeneration,
+        RBLNViTForImageClassification,
         RBLNXLMRobertaForSequenceClassification,
     )
     from .modeling_base import (
@@ -136,7 +138,7 @@ if TYPE_CHECKING:
         RBLNModelForSequenceClassification,
     )
     from .modeling_config import RBLNCompileConfig, RBLNConfig
-    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
+    from .modeling_diffusers import RBLNDiffusionMixin
     from .transformers import (
         BatchTextIteratorStreamer,
         RBLNAutoModel,
@@ -151,12 +153,14 @@ if TYPE_CHECKING:
         RBLNAutoModelForSequenceClassification,
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForVision2Seq,
+        RBLNBartForConditionalGeneration,
         RBLNBartModel,
         RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
+        RBLNExaoneForCausalLM,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
@@ -164,6 +168,9 @@ if TYPE_CHECKING:
         RBLNMidmLMHeadModel,
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
+        RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
+        RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
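
The `_import_structure` dict and the `if TYPE_CHECKING:` block above are the two halves of the lazy-module convention used across Hugging Face libraries: names are registered in a dict and only resolved on first attribute access, while the `TYPE_CHECKING` imports keep static type checkers happy. The tail of the file is not shown in this diff; a minimal sketch of what it conventionally looks like, assuming the package delegates to `transformers.utils._LazyModule` as most optimum extensions do:

```python
# Sketch of the lazy-import tail such an __init__.py typically has. The
# _LazyModule import path is an assumption based on the Hugging Face
# convention; this diff does not display the actual tail of the file.
import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {
    "modeling_diffusers": ["RBLNDiffusionMixin"],
}

if TYPE_CHECKING:
    from .modeling_diffusers import RBLNDiffusionMixin  # noqa: F401
else:
    # Replace this module with a proxy that imports submodules on demand,
    # so `import optimum.rbln` stays cheap until a symbol is first touched.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
```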
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = '0.1.11'
+__version__ = '0.1.13'
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -22,7 +22,6 @@
 # from Rebellions Inc.
 
 import logging
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
@@ -30,11 +29,11 @@ import torch  # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig
 
 from ...modeling_base import RBLNModel
 from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ...utils.context import override_auto_classes
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
@@ -58,15 +57,12 @@ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
 
 
 class RBLNAutoencoderKL(RBLNModel):
-    model_type = "rbln_model"
     config_name = "config.json"
-    auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
-        if self.rbln_use_encode:
+        if self.rbln_config.model_cfg.get("img2img_pipeline"):
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
@@ -93,38 +89,15 @@ class RBLNAutoencoderKL(RBLNModel):
 
             return dec_compiled_model
 
-        if rbln_config.model_cfg.get("use_encode", False):
+        if rbln_config.model_cfg.get("img2img_pipeline"):
             return compile_img2img()
         else:
             return compile_text2img()
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
-        ):
-            return AutoencoderKL.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-
-        if kwargs.get("export", None):
-            # This is an ad-hoc to workaround save null values of the config.
-            # if export, pure optimum(not optimum-rbln) loads config using AutoConfig
-            # and diffusers model do not support loading by AutoConfig.
-            AutoConfig.from_pretrained = lambda *args, **kwargs: None
-        else:
-            AutoConfig.from_pretrained = AutoencoderKL.load_config
-
-        AutoModel.from_pretrained = AutoencoderKL.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
+        with override_auto_classes(config_func=AutoencoderKL.load_config, model_func=AutoencoderKL.from_pretrained):
+            rt = super().from_pretrained(*args, **kwargs)
         return rt
 
     @classmethod
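
Both here and in controlnet.py below, the release replaces roughly twenty lines of manual monkey-patching of `AutoConfig.from_pretrained`, `AutoModel.from_pretrained`, and `TasksManager.get_model_from_task` with a single `override_auto_classes` context manager, added in the new `optimum/rbln/utils/context.py` (+58 lines in the file list above). That file's diff is not shown on this page; a minimal sketch of what the context manager plausibly does, inferred from the code it replaces, with restoration guaranteed by `finally`:

```python
# Hypothetical sketch of override_auto_classes; the real implementation lives
# in optimum/rbln/utils/context.py, whose diff this page does not display.
from contextlib import contextmanager

from optimum.exporters import TasksManager
from transformers import AutoConfig, AutoModel


@contextmanager
def override_auto_classes(config_func, model_func):
    """Temporarily route the Auto-class loaders to diffusers-specific loaders."""
    saved = (
        AutoConfig.from_pretrained,
        AutoModel.from_pretrained,
        TasksManager.get_model_from_task,
    )
    try:
        AutoConfig.from_pretrained = config_func
        AutoModel.from_pretrained = model_func
        # Diffusers models are not registered with TasksManager, so bypass
        # task-based lookup and load the concrete class directly.
        TasksManager.get_model_from_task = (
            lambda task, model_name_or_path, **kwargs: model_func(model_name_or_path, **kwargs)
        )
        yield
    finally:
        (
            AutoConfig.from_pretrained,
            AutoModel.from_pretrained,
            TasksManager.get_model_from_task,
        ) = saved
```

Unlike the sequential save/patch/restore it replaces, the `finally` block also restores the originals if `from_pretrained` raises mid-load.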
@@ -134,34 +107,39 @@ class RBLNAutoencoderKL(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
-        rbln_img_width = rbln_kwargs.get("img_width", None)
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_use_encode = rbln_kwargs.get("use_encode", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size")
+        sample_size = rbln_kwargs.get("sample_size")
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        model_cfg = {}
+        if sample_size is None:
+            sample_size = model_config.sample_size
+
+        if isinstance(sample_size, int):
+            sample_size = (sample_size, sample_size)
+
+        if hasattr(model_config, "block_out_channels"):
+            vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+        else:
+            # vae image processor default value 8 (int)
+            vae_scale_factor = 8
 
-        if rbln_use_encode:
-            model_cfg["img_width"] = rbln_img_width
-            model_cfg["img_height"] = rbln_img_height
+        dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
+        enc_shape = (sample_size[0], sample_size[1])
 
+        if rbln_kwargs["img2img_pipeline"]:
             vae_enc_input_info = [
-                ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
+                (
+                    "x",
+                    [rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
+                    "float32",
+                )
             ]
             vae_dec_input_info = [
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_img_height // rbln_vae_scale_factor,
-                        rbln_img_width // rbln_vae_scale_factor,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
             ]
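
The new logic derives every static compile shape from the VAE config instead of requiring callers to pass `img_width`, `img_height`, and `vae_scale_factor` explicitly. A quick worked example of the derivation, assuming the config values of a stock Stable Diffusion 1.5 VAE (`sample_size=512`, four `block_out_channels`):

```python
# Worked example of the shape derivation above, using assumed config values
# from a stock Stable Diffusion 1.5 VAE (sample_size=512, 4 block levels).
sample_size = 512
block_out_channels = [128, 256, 512, 512]

if isinstance(sample_size, int):
    sample_size = (sample_size, sample_size)            # (512, 512)

vae_scale_factor = 2 ** (len(block_out_channels) - 1)   # 2**3 == 8

enc_shape = (sample_size[0], sample_size[1])            # encoder input: full image, (512, 512)
dec_shape = (sample_size[0] // vae_scale_factor,
             sample_size[1] // vae_scale_factor)        # decoder input: latents, (64, 64)

assert vae_scale_factor == 8
assert dec_shape == (64, 64)
```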
@@ -175,33 +153,22 @@ class RBLNAutoencoderKL(RBLNModel):
                 compile_cfgs=compile_cfgs,
                 rbln_kwargs=rbln_kwargs,
             )
-            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
-        if rbln_unet_sample_size is None:
-            rbln_unet_sample_size = 64
-
-        model_cfg["unet_sample_size"] = rbln_unet_sample_size
         vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
-                    [
-                        rbln_batch_size,
-                        model_config.latent_channels,
-                        rbln_unet_sample_size,
-                        rbln_unet_sample_size,
-                    ],
+                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
                     "float32",
                 )
-            ],
+            ]
         )
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
             compile_cfgs=[vae_config],
             rbln_kwargs=rbln_kwargs,
         )
-        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -22,16 +22,15 @@
 # from Rebellions Inc.
 
 import logging
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
-from optimum.exporters import TasksManager
-from transformers import AutoConfig, AutoModel, PretrainedConfig
+from transformers import PretrainedConfig
 
 from ...modeling_base import RBLNModel
 from ...modeling_config import RBLNCompileConfig, RBLNConfig
+from ...utils.context import override_auto_classes
 
 
 if TYPE_CHECKING:
@@ -105,9 +104,6 @@ class _ControlNetModel_Cross_Attention(torch.nn.Module):
 
 
 class RBLNControlNetModel(RBLNModel):
-    model_type = "rbln_model"
-    auto_model_class = AutoModel  # feature extraction
-
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
@@ -116,26 +112,11 @@ class RBLNControlNetModel(RBLNModel):
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
-        if "subfolder" in kwargs:
-            del kwargs["subfolder"]
-
-        def get_model_from_task(
-            task: str,
-            model_name_or_path: Union[str, Path],
-            **kwargs,
+        with override_auto_classes(
+            config_func=ControlNetModel.load_config,
+            model_func=ControlNetModel.from_pretrained,
         ):
-            return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
-
-        tasktmp = TasksManager.get_model_from_task
-        configtmp = AutoConfig.from_pretrained
-        modeltmp = AutoModel.from_pretrained
-        TasksManager.get_model_from_task = get_model_from_task
-        AutoConfig.from_pretrained = ControlNetModel.load_config
-        AutoModel.from_pretrained = ControlNetModel.from_pretrained
-        rt = super().from_pretrained(*args, **kwargs)
-        AutoConfig.from_pretrained = configtmp
-        AutoModel.from_pretrained = modeltmp
-        TasksManager.get_model_from_task = tasktmp
+            rt = super().from_pretrained(*args, **kwargs)
         return rt
 
     @classmethod
@@ -157,33 +138,35 @@ class RBLNControlNetModel(RBLNModel):
         model_config: "PretrainedConfig",
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_img_width = rbln_kwargs.get("img_width", None)
-        rbln_img_height = rbln_kwargs.get("img_height", None)
-        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        batch_size = rbln_kwargs.get("batch_size")
+        max_seq_len = rbln_kwargs.get("max_seq_len")
+        unet_sample_size = rbln_kwargs.get("unet_sample_size")
+        vae_sample_size = rbln_kwargs.get("vae_sample_size")
 
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        if batch_size is None:
+            batch_size = 1
 
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = 77
+        if unet_sample_size is None:
+            raise ValueError(
+                "`rbln_unet_sample_size` (latent height, width) must be specified (ex. unet's sample_size)"
+            )
 
-        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
-            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+        if vae_sample_size is None:
+            raise ValueError(
+                "`rbln_vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)"
+            )
 
-        input_width = rbln_img_width // rbln_vae_scale_factor
-        input_height = rbln_img_height // rbln_vae_scale_factor
+        if max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified")
 
         input_info = [
             (
                 "sample",
                 [
-                    rbln_batch_size,
+                    batch_size,
                     model_config.in_channels,
-                    input_height,
-                    input_width,
+                    unet_sample_size[0],
+                    unet_sample_size[1],
                 ],
                 "float32",
             ),
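
The required kwargs thus move from pixel dimensions plus a scale factor (`rbln_img_width`, `rbln_img_height`, `rbln_vae_scale_factor`) to the sample sizes of the UNet and VAE directly, as the new error messages spell out. A hedged usage sketch (the checkpoint ID is a placeholder, and calling the model class standalone like this rather than through a pipeline is our assumption, not something this diff shows):

```python
# Illustrative only: checkpoint ID is a placeholder; in practice the
# ControlNet pipelines normally supply these rbln_* values themselves.
from optimum.rbln.diffusers.models.controlnet import RBLNControlNetModel

controlnet = RBLNControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",   # placeholder ControlNet checkpoint
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=77,                # text_encoder's max_position_embeddings
    rbln_unet_sample_size=(64, 64),     # latent H/W, e.g. unet.config.sample_size
    rbln_vae_sample_size=(512, 512),    # image H/W, e.g. vae.config.sample_size
)
```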
@@ -195,23 +178,24 @@ class RBLNControlNetModel(RBLNModel):
             input_info.append(
                 (
                     "encoder_hidden_states",
-                    [
-                        rbln_batch_size,
-                        rbln_max_seq_len,
-                        model_config.cross_attention_dim,
-                    ],
+                    [batch_size, max_seq_len, model_config.cross_attention_dim],
                     "float32",
                 )
             )
 
-        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(
+            (
+                "controlnet_cond",
+                [batch_size, 3, vae_sample_size[0], vae_sample_size[1]],
+                "float32",
+            )
+        )
         input_info.append(("conditioning_scale", [], "float32"))
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            if rbln_text_model_hidden_size is None:
-                rbln_text_model_hidden_size = 768
-            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
-            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_text_model_hidden_size = rbln_kwargs["text_model_hidden_size"]
+            input_info.append(("text_embeds", [batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [batch_size, 6], "float32"))
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
@@ -221,16 +205,6 @@ class RBLNControlNetModel(RBLNModel):
             rbln_kwargs=rbln_kwargs,
         )
 
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "img_width": rbln_img_width,
-                "img_height": rbln_img_height,
-                "vae_scale_factor": rbln_vae_scale_factor,
-            }
-        )
-
         return rbln_config
 
     def forward(