optimum_rbln-0.1.8-py3-none-any.whl → optimum_rbln-0.1.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. optimum/rbln/__init__.py +40 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +39 -32
  4. optimum/rbln/diffusers/models/controlnet.py +60 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +43 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  15. optimum/rbln/modeling_alias.py +8 -4
  16. optimum/rbln/modeling_base.py +512 -238
  17. optimum/rbln/modeling_config.py +152 -77
  18. optimum/rbln/modeling_seq2seq.py +166 -77
  19. optimum/rbln/transformers/__init__.py +37 -1
  20. optimum/rbln/transformers/models/__init__.py +21 -1
  21. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  22. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  23. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  24. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  25. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  26. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  27. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +128 -26
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +32 -7
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +406 -104
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  34. optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
  35. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
  36. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  37. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
  38. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
  39. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  40. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  41. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  42. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
  43. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  44. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  45. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  46. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  49. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  50. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +18 -12
  51. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  53. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  54. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +25 -16
  55. optimum/rbln/transformers/utils/__init__.py +0 -0
  56. optimum/rbln/transformers/utils/rbln_quantization.py +97 -0
  57. optimum/rbln/utils/import_utils.py +37 -5
  58. optimum/rbln/utils/logging.py +82 -0
  59. optimum/rbln/utils/runtime_utils.py +35 -1
  60. optimum/rbln/utils/timer_utils.py +19 -0
  61. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +15 -7
  62. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  63. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  64. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  65. optimum_rbln-0.1.8.dist-info/RECORD +0 -73
  66. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py CHANGED
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING
 
 from transformers.utils import _LazyModule
 
+from .__version__ import __version__
 from .utils import check_version_compats
 
 
@@ -32,6 +33,7 @@ _import_structure = {
     "modeling_alias": [
         "RBLNASTForAudioClassification",
         "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
         "RBLNResNetForImageClassification",
         "RBLNT5ForConditionalGeneration",
         "RBLNBartForConditionalGeneration",
@@ -53,14 +55,32 @@ _import_structure = {
     ],
     "transformers": [
         "BatchTextIteratorStreamer",
+        "RBLNAutoModel",
+        "RBLNAutoModelForAudioClassification",
+        "RBLNAutoModelForCausalLM",
+        "RBLNAutoModelForCTC",
+        "RBLNAutoModelForDepthEstimation",
+        "RBLNAutoModelForImageClassification",
+        "RBLNAutoModelForMaskedLM",
+        "RBLNAutoModelForQuestionAnswering",
+        "RBLNAutoModelForSeq2SeqLM",
+        "RBLNAutoModelForSequenceClassification",
+        "RBLNAutoModelForSpeechSeq2Seq",
+        "RBLNAutoModelForVision2Seq",
+        "RBLNBartModel",
+        "RBLNBertModel",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPVisionModel",
         "RBLNDPTForDepthEstimation",
         "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNLlamaForCausalLM",
+        "RBLNPhiForCausalLM",
+        "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
         "RBLNWhisperForConditionalGeneration",
         "RBLNXLMRobertaModel",
     ],
@@ -78,7 +98,7 @@ _import_structure = {
         "RBLNStableDiffusionXLControlNetPipeline",
         "RBLNStableDiffusionXLControlNetImg2ImgPipeline",
     ],
-    "modeling_config": ["RBLNRuntimeConfig", "RBLNConfig"],
+    "modeling_config": ["RBLNCompileConfig", "RBLNConfig"],
 }
 
 if TYPE_CHECKING:
@@ -115,17 +135,35 @@ if TYPE_CHECKING:
         RBLNModelForQuestionAnswering,
         RBLNModelForSequenceClassification,
     )
-    from .modeling_config import RBLNConfig, RBLNRuntimeConfig
+    from .modeling_config import RBLNCompileConfig, RBLNConfig
    from .modeling_seq2seq import RBLNModelForSeq2SeqLM
     from .transformers import (
         BatchTextIteratorStreamer,
+        RBLNAutoModel,
+        RBLNAutoModelForAudioClassification,
+        RBLNAutoModelForCausalLM,
+        RBLNAutoModelForCTC,
+        RBLNAutoModelForDepthEstimation,
+        RBLNAutoModelForImageClassification,
+        RBLNAutoModelForMaskedLM,
+        RBLNAutoModelForQuestionAnswering,
+        RBLNAutoModelForSeq2SeqLM,
+        RBLNAutoModelForSequenceClassification,
+        RBLNAutoModelForSpeechSeq2Seq,
+        RBLNAutoModelForVision2Seq,
+        RBLNBartModel,
+        RBLNBertModel,
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPVisionModel,
         RBLNDPTForDepthEstimation,
         RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
+        RBLNLlavaNextForConditionalGeneration,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
+        RBLNPhiForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
         RBLNXLMRobertaModel,
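
The headline change in `__init__.py` is the new `RBLNAutoModel*` family, mirroring the `transformers` auto classes. A minimal usage sketch under the usual optimum-style conventions; the checkpoint id and the `export`/`rbln_*` keyword arguments are illustrative assumptions, not confirmed by this diff:

    # Hypothetical example: load-and-compile via the new auto class.
    # `export=True` and `rbln_batch_size` follow the optimum-style kwargs
    # convention and are assumed here for illustration.
    from optimum.rbln import RBLNAutoModelForCausalLM

    model = RBLNAutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
        export=True,                 # compile the checkpoint for the RBLN NPU
        rbln_batch_size=1,           # forwarded into rbln_kwargs downstream
    )
    # From here the object is used like its transformers counterpart,
    # e.g. model.generate(input_ids).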
optimum/rbln/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = '0.1.8'
+__version__ = '0.1.11'
optimum/rbln/diffusers/models/autoencoder_kl.py CHANGED
@@ -23,10 +23,10 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 import rebel
-import torch
+import torch  # noqa: I001
 from diffusers import AutoencoderKL
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
@@ -34,16 +34,16 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ...utils.runtime_utils import RBLNPytorchRuntime
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     import torch
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
@@ -63,10 +63,9 @@ class RBLNAutoencoderKL(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
-
-        self.rbln_use_encode = self.rbln_config.meta["rbln_use_encode"]
+        super().__post_init__(**kwargs)
 
+        self.rbln_use_encode = self.rbln_config.model_cfg["use_encode"]
         if self.rbln_use_encode:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
             self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
@@ -81,20 +80,20 @@ class RBLNAutoencoderKL(RBLNModel):
             encoder_model.eval()
             decoder_model.eval()
 
-            enc_compiled_model = cls.compile(encoder_model, rbln_runtime_config=rbln_config["encoder"][0])
-            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["decoder"][0])
+            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
 
-            return enc_compiled_model, dec_compiled_model
+            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
 
         def compile_text2img():
             decoder_model = _VAEDecoder(model)
             decoder_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model, rbln_runtime_config=rbln_config["compiled_model"][0])
+            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
 
             return dec_compiled_model
 
-        if rbln_config.meta.get("rbln_use_encode", False):
+        if rbln_config.model_cfg.get("use_encode", False):
             return compile_img2img()
         else:
             return compile_text2img()
@@ -133,23 +132,23 @@ class RBLNAutoencoderKL(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_unet_sample_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_unet_sample_size = rbln_kwargs.get("unet_sample_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        meta["rbln_use_encode"] = rbln_use_encode
-        meta["rbln_batch_size"] = rbln_batch_size
+        model_cfg = {}
 
         if rbln_use_encode:
-            meta["rbln_img_width"] = rbln_img_width
-            meta["rbln_img_height"] = rbln_img_height
+            model_cfg["img_width"] = rbln_img_width
+            model_cfg["img_height"] = rbln_img_height
 
             vae_enc_input_info = [
                 ("x", [rbln_batch_size, model_config.in_channels, rbln_img_height, rbln_img_width], "float32")
@@ -167,20 +166,23 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ]
 
-            enc_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
-            dec_rbln_runtime_config = RBLNRuntimeConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
+            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
 
-            rbln_config = RBLNConfig.from_rbln_runtime_configs(
-                [enc_rbln_runtime_config, dec_rbln_runtime_config],
-                _rbln_meta=meta,
+            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
+            rbln_config = RBLNConfig(
+                rbln_cls=cls.__name__,
+                compile_cfgs=compile_cfgs,
+                rbln_kwargs=rbln_kwargs,
             )
+            rbln_config.model_cfg.update(model_cfg)
             return rbln_config
 
         if rbln_unet_sample_size is None:
            rbln_unet_sample_size = 64
 
-        meta["rbln_unet_sample_size"] = rbln_unet_sample_size
-        vae_config = RBLNRuntimeConfig(
+        model_cfg["unet_sample_size"] = rbln_unet_sample_size
+        vae_config = RBLNCompileConfig(
             input_info=[
                 (
                     "z",
@@ -194,7 +196,12 @@ class RBLNAutoencoderKL(RBLNModel):
                 )
             ],
         )
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([vae_config], _rbln_meta=meta)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[vae_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(model_cfg)
         return rbln_config
 
     @classmethod
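
Across these hunks the configuration model changes shape: `RBLNRuntimeConfig` becomes `RBLNCompileConfig`, the per-model `rbln_*` keyword parameters collapse into a single `rbln_kwargs` dict (keyed without the `rbln_` prefix), and the old `meta` dict is replaced by `rbln_config.model_cfg`. A distilled sketch of the new pattern, with names taken from the diff and an illustrative VAE-decoder input shape:

    # Sketch of the 0.1.11 config construction, assuming RBLNCompileConfig and
    # RBLNConfig behave as the hunks above show; the shape values are made up.
    from optimum.rbln import RBLNCompileConfig, RBLNConfig

    input_info = [("z", [1, 4, 64, 64], "float32")]  # (name, shape, dtype) triples
    compile_cfg = RBLNCompileConfig(input_info=input_info)

    rbln_config = RBLNConfig(
        rbln_cls="RBLNAutoencoderKL",   # the RBLN model class being compiled
        compile_cfgs=[compile_cfg],     # one entry per compiled submodule
        rbln_kwargs={"batch_size": 1},  # user-facing options, unprefixed keys
    )
    rbln_config.model_cfg.update({"use_encode": False})  # replaces the old meta dict

One consequence visible in `get_compiled_model`: with multiple submodules (encoder and decoder), compiled artifacts are now returned as a named dict rather than a positional tuple, matching the ordered `compile_cfgs` list.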
optimum/rbln/diffusers/models/controlnet.py CHANGED
@@ -23,7 +23,7 @@
 
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
 from diffusers import ControlNetModel
@@ -31,15 +31,16 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
+logger = logging.getLogger(__name__)
+
+
 class _ControlNetModel(torch.nn.Module):
     def __init__(self, controlnet: "ControlNetModel"):
         super().__init__()
@@ -108,21 +109,21 @@ class RBLNControlNetModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
+        super().__post_init__(**kwargs)
         self.use_encoder_hidden_states = any(
-            item[0] == "encoder_hidden_states" for item in self.rbln_config["compiled_model"][0].input_info
+            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
         )
 
     @classmethod
     def from_pretrained(cls, *args, **kwargs):
+        if "subfolder" in kwargs:
+            del kwargs["subfolder"]
+
         def get_model_from_task(
             task: str,
             model_name_or_path: Union[str, Path],
             **kwargs,
         ):
-            if "subfolder" in kwargs:
-                del kwargs["subfolder"]
-
             return ControlNetModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path, **kwargs)
 
         tasktmp = TasksManager.get_model_from_task
@@ -138,7 +139,7 @@ class RBLNControlNetModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         use_encoder_hidden_states = False
         for down_block in model.down_blocks:
             if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
@@ -154,14 +155,14 @@ class RBLNControlNetModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {"type": "controlnet"}
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -169,28 +170,29 @@ class RBLNControlNetModel(RBLNModel):
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
 
+        if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+            raise ValueError("rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided")
+
         input_width = rbln_img_width // rbln_vae_scale_factor
         input_height = rbln_img_height // rbln_vae_scale_factor
 
-        rbln_runtime_config = RBLNRuntimeConfig(
-            input_info=[
-                (
-                    "sample",
-                    [
-                        rbln_batch_size,
-                        model_config.in_channels,
-                        input_height,
-                        input_width,
-                    ],
-                    "float32",
-                ),
-                ("timestep", [], "float32"),
-            ],
-            batch_size=rbln_batch_size,
-        )
+        input_info = [
+            (
+                "sample",
+                [
+                    rbln_batch_size,
+                    model_config.in_channels,
+                    input_height,
+                    input_width,
+                ],
+                "float32",
+            ),
+            ("timestep", [], "float32"),
+        ]
+
         use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
         if use_encoder_hidden_states:
-            rbln_runtime_config.input_info.append(
+            input_info.append(
                 (
                     "encoder_hidden_states",
                     [
@@ -201,19 +203,34 @@ class RBLNControlNetModel(RBLNModel):
                     "float32",
                 )
             )
-        rbln_runtime_config.input_info.append(
-            ("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32")
-        )
-        rbln_runtime_config.input_info.append(("conditioning_scale", [], "float32"))
+
+        input_info.append(("controlnet_cond", [rbln_batch_size, 3, rbln_img_height, rbln_img_width], "float32"))
+        input_info.append(("conditioning_scale", [], "float32"))
+
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
-            rbln_runtime_config.input_info.append(
-                ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
-            )
-            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            input_info.append(("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32"))
+            input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "img_width": rbln_img_width,
+                "img_height": rbln_img_height,
+                "vae_scale_factor": rbln_vae_scale_factor,
+            }
        )
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
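
One behavioral change worth noting in this file: `_get_rbln_config` now raises a `ValueError` up front when the image geometry is missing, instead of failing later inside the shape arithmetic. The latent resolution it derives is simply the pixel size divided by the VAE scale factor; a worked example with typical Stable Diffusion values (the numbers are illustrative, not defaults taken from this diff):

    # 512x512 pixels with the usual VAE scale factor of 8 gives 64x64 latents.
    rbln_kwargs = {"img_width": 512, "img_height": 512, "vae_scale_factor": 8}
    input_width = rbln_kwargs["img_width"] // rbln_kwargs["vae_scale_factor"]    # 64
    input_height = rbln_kwargs["img_height"] // rbln_kwargs["vae_scale_factor"]  # 64
    # Omitting any of the three geometry keys now raises ValueError before compilation.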
optimum/rbln/diffusers/models/unet_2d_condition.py CHANGED
@@ -32,14 +32,14 @@ from optimum.exporters import TasksManager
 from transformers import AutoConfig, AutoModel, PretrainedConfig
 
 from ...modeling_base import RBLNModel
-from ...modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...modeling_config import RBLNCompileConfig, RBLNConfig
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
+logger = logging.getLogger(__name__)
+
 
 class _UNet_SD(torch.nn.Module):
     def __init__(self, unet: "UNet2DConditionModel"):
@@ -130,8 +130,8 @@ class RBLNUNet2DConditionModel(RBLNModel):
     auto_model_class = AutoModel  # feature extraction
 
     def __post_init__(self, **kwargs):
-        self.dtype = torch.float32
-        self.in_features = self.rbln_config.meta.get("in_features", None)
+        super().__post_init__(**kwargs)
+        self.in_features = self.rbln_config.model_cfg.get("in_features", None)
         if self.in_features is not None:
 
             @dataclass
@@ -172,7 +172,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
         return rt
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         if model.config.addition_embed_type == "text_time":
             return _UNet_SDXL(model).eval()
         else:
@@ -183,31 +183,31 @@ class RBLNUNet2DConditionModel(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_text_model_hidden_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_in_features: Optional[int] = None,
-        rbln_use_encode: Optional[bool] = None,
-        rbln_img_width: Optional[int] = None,
-        rbln_img_height: Optional[int] = None,
-        rbln_vae_scale_factor: Optional[int] = None,
-        rbln_is_controlnet: Optional[bool] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {"type": "unet"}
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_text_model_hidden_size = rbln_kwargs.get("text_model_hidden_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_in_features = rbln_kwargs.get("in_features", None)
+        rbln_use_encode = rbln_kwargs.get("use_encode", None)
+        rbln_img_width = rbln_kwargs.get("img_width", None)
+        rbln_img_height = rbln_kwargs.get("img_height", None)
+        rbln_vae_scale_factor = rbln_kwargs.get("vae_scale_factor", None)
+        rbln_is_controlnet = rbln_kwargs.get("is_controlnet", None)
 
         if rbln_max_seq_len is None:
             rbln_max_seq_len = 77
-
-        meta["rbln_use_encode"] = rbln_use_encode
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
 
         if rbln_use_encode:
-            # FIXME :: robust img shape getter
+            if rbln_img_width is None or rbln_img_height is None or rbln_vae_scale_factor is None:
+                raise ValueError(
+                    "rbln_img_width, rbln_img_height, and rbln_vae_scale_factor must be provided when rbln_use_encode is True"
+                )
             input_width = rbln_img_width // rbln_vae_scale_factor
             input_height = rbln_img_height // rbln_vae_scale_factor
         else:
-            # FIXME :: model_config.sample_size can be tuple or list
             input_width, input_height = model_config.sample_size, model_config.sample_size
 
         input_info = [
@@ -232,6 +232,7 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 "float32",
             ),
         ]
+
         if rbln_is_controlnet:
             if len(model_config.block_out_channels) > 0:
                 input_info.extend(
@@ -304,24 +305,35 @@ class RBLNUNet2DConditionModel(RBLNModel):
                 )
             )
 
-        rbln_runtime_config = RBLNRuntimeConfig(
-            input_info=input_info,
-            batch_size=rbln_batch_size,
-        )
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
         if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
-            # In case of sdxl
             if rbln_text_model_hidden_size is None:
                 rbln_text_model_hidden_size = 768
             if rbln_in_features is None:
                 rbln_in_features = model_config.projection_class_embeddings_input_dim
-            meta["in_features"] = rbln_in_features
-            rbln_runtime_config.input_info.append(
+            rbln_compile_config.input_info.append(
                 ("text_embeds", [rbln_batch_size, rbln_text_model_hidden_size], "float32")
             )
-            rbln_runtime_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+            rbln_compile_config.input_info.append(("time_ids", [rbln_batch_size, 6], "float32"))
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "use_encode": rbln_use_encode,
+            }
+        )
+
+        if rbln_in_features is not None:
+            rbln_config.model_cfg["in_features"] = rbln_in_features
 
-        rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
         return rbln_config
 
     def forward(
optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py CHANGED
@@ -37,11 +37,11 @@ from ....modeling_config import RBLNConfig
 from ...models.controlnet import RBLNControlNetModel
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     pass
 
+logger = logging.getLogger(__name__)
+
 
 class RBLNMultiControlNetModel(RBLNModel):
     def __init__(
@@ -79,7 +79,6 @@ class RBLNMultiControlNetModel(RBLNModel):
         model_id: Union[str, Path],
         **kwargs,
     ) -> RBLNModel:
-
         idx = 0
         controlnets = []
         model_path_to_load = model_id