optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. optimum/rbln/__init__.py +164 -36
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +107 -78
  31. optimum/rbln/transformers/__init__.py +87 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +108 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  73. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  74. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  75. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  76. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  77. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  78. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  79. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  80. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  81. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  82. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  83. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  84. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  85. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  86. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  87. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  88. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  89. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  90. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  91. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  92. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  93. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  94. optimum/rbln/utils/runtime_utils.py +33 -2
  95. optimum/rbln/utils/submodule.py +26 -43
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
  97. optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
  98. optimum/rbln/modeling_config.py +0 -310
  99. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  100. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
  101. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
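The headline change in this release is a new typed configuration system: optimum/rbln/configuration_utils.py (+772 lines) replaces the removed optimum/rbln/modeling_config.py, and each model family gains configuration_*.py modules defining per-model config classes (RBLNAutoencoderKLConfig, RBLNVQModelConfig, and so on). The three expanded diffs below (autoencoder_kl.py, vae.py, vq_model.py) show the migration in detail. As a rough sketch of the new flow, with the constructor arguments and model id being illustrative assumptions rather than values confirmed by this diff:

# Hypothetical usage of the typed config introduced in 0.7.4a6.
from optimum.rbln import RBLNAutoencoderKL
from optimum.rbln.diffusers.configurations import RBLNAutoencoderKLConfig

# Fields mirror the attributes read in the diffs below (batch_size, sample_size, ...).
rbln_config = RBLNAutoencoderKLConfig(batch_size=1, sample_size=(512, 512))
vae = RBLNAutoencoderKL.from_pretrained(
    "some-org/some-diffusion-model",  # placeholder model id
    subfolder="vae",
    export=True,
    rbln_config=rbln_config,
)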
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py
@@ -12,73 +12,72 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 
 import rebel
-import torch # noqa: I001
+import torch
 from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DecoderOutput
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 from transformers import PretrainedConfig
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations import RBLNAutoencoderKLConfig
 from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
 
 
 if TYPE_CHECKING:
     import torch
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 
 
 logger = get_logger(__name__)
 
 
 class RBLNAutoencoderKL(RBLNModel):
     auto_model_class = AutoencoderKL
-    config_name = "config.json"
     hf_library_name = "diffusers"
+    _rbln_config_class = RBLNAutoencoderKLConfig
 
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        if self.rbln_config.model_cfg.get("img2img_pipeline") or self.rbln_config.model_cfg.get("inpaint_pipeline"):
+        if self.rbln_config.uses_encoder:
             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
-            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[1], main_input_name="z")
         else:
-            self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[0], main_input_name="z")
+            self.encoder = None
 
-        self.image_size = self.rbln_config.model_cfg["sample_size"]
+        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+        self.image_size = self.rbln_config.image_size
 
     @classmethod
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-        def compile_img2img():
-            encoder_model = _VAEEncoder(model)
-            decoder_model = _VAEDecoder(model)
-            encoder_model.eval()
-            decoder_model.eval()
-
-            enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
-            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
+    def get_compiled_model(cls, model, rbln_config: RBLNAutoencoderKLConfig) -> Dict[str, rebel.RBLNCompiledModel]:
+        if rbln_config.uses_encoder:
+            expected_models = ["encoder", "decoder"]
+        else:
+            expected_models = ["decoder"]
 
-            return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
+        compiled_models = {}
+        for i, model_name in enumerate(expected_models):
+            if model_name == "encoder":
+                wrapped_model = _VAEEncoder(model)
+            else:
+                wrapped_model = _VAEDecoder(model)
 
-        def compile_text2img():
-            decoder_model = _VAEDecoder(model)
-            decoder_model.eval()
+            wrapped_model.eval()
 
-            dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
+            compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
 
-            return dec_compiled_model
-
-        if rbln_config.model_cfg.get("img2img_pipeline") or rbln_config.model_cfg.get("inpaint_pipeline"):
-            return compile_img2img()
-        else:
-            return compile_text2img()
+        return compiled_models
 
     @classmethod
-    def get_vae_sample_size(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
-        image_size = (rbln_config.get("img_height"), rbln_config.get("img_width"))
+    def get_vae_sample_size(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: RBLNAutoencoderKLConfig, return_vae_scale_factor: bool = False
+    ) -> Tuple[int, int]:
+        sample_size = rbln_config.sample_size
         noise_module = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
         vae_scale_factor = (
             pipe.vae_scale_factor
@@ -91,139 +90,121 @@ class RBLNAutoencoderKL(RBLNModel):
                 "Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
             )
 
-        if (image_size[0] is None) != (image_size[1] is None):
-            raise ValueError("Both image height and image width must be given or not given")
+        if sample_size is None:
+            sample_size = noise_module.config.sample_size
+            if isinstance(sample_size, int):
+                sample_size = (sample_size, sample_size)
+            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
 
-        elif image_size[0] is None and image_size[1] is None:
-            if rbln_config["img2img_pipeline"]:
-                sample_size = noise_module.config.sample_size
-            elif rbln_config["inpaint_pipeline"]:
-                sample_size = noise_module.config.sample_size * vae_scale_factor
-            else:
-                # In case of text2img, sample size of vae decoder is determined by unet.
-                noise_module_sample_size = noise_module.config.sample_size
-                if isinstance(noise_module_sample_size, int):
-                    sample_size = noise_module_sample_size * vae_scale_factor
-                else:
-                    sample_size = (
-                        noise_module_sample_size[0] * vae_scale_factor,
-                        noise_module_sample_size[1] * vae_scale_factor,
-                    )
+        if return_vae_scale_factor:
+            return sample_size, vae_scale_factor
         else:
-            sample_size = (image_size[0], image_size[1])
-
-        return sample_size
+            return sample_size
 
     @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        rbln_config.update({"sample_size": cls.get_vae_sample_size(pipe, rbln_config)})
+    def update_rbln_config_using_pipe(
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
+        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+            pipe, rbln_config.vae, return_vae_scale_factor=True
+        )
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_batch_size = rbln_kwargs.get("batch_size")
-        sample_size = rbln_kwargs.get("sample_size")
-        is_img2img = rbln_kwargs.get("img2img_pipeline")
-        is_inpaint = rbln_kwargs.get("inpaint_pipeline")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        if sample_size is None:
-            sample_size = model_config.sample_size
+        rbln_config: RBLNAutoencoderKLConfig,
+    ) -> RBLNAutoencoderKLConfig:
+        if rbln_config.sample_size is None:
+            rbln_config.sample_size = model_config.sample_size
 
-        if isinstance(sample_size, int):
-            sample_size = (sample_size, sample_size)
+        if isinstance(rbln_config.sample_size, int):
+            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)
 
-        rbln_kwargs["sample_size"] = sample_size
+        if rbln_config.in_channels is None:
+            rbln_config.in_channels = model_config.in_channels
 
-        if hasattr(model_config, "block_out_channels"):
-            vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
-        else:
-            # vae image processor default value 8 (int)
-            vae_scale_factor = 8
+        if rbln_config.latent_channels is None:
+            rbln_config.latent_channels = model_config.latent_channels
 
-        dec_shape = (sample_size[0] // vae_scale_factor, sample_size[1] // vae_scale_factor)
-        enc_shape = (sample_size[0], sample_size[1])
+        if rbln_config.vae_scale_factor is None:
+            if hasattr(model_config, "block_out_channels"):
+                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+            else:
+                # vae image processor default value 8 (int)
+                rbln_config.vae_scale_factor = 8
 
-        if is_img2img or is_inpaint:
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
             vae_enc_input_info = [
                 (
                     "x",
-                    [rbln_batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
+                    [
+                        rbln_config.batch_size,
+                        rbln_config.in_channels,
+                        rbln_config.sample_size[0],
+                        rbln_config.sample_size[1],
+                    ],
                     "float32",
                 )
             ]
-            vae_dec_input_info = [
-                (
-                    "z",
-                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
-                    "float32",
-                )
-            ]
-
-            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info)
-            dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info)
-
-            compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
-            rbln_config = RBLNConfig(
-                rbln_cls=cls.__name__,
-                compile_cfgs=compile_cfgs,
-                rbln_kwargs=rbln_kwargs,
+            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+        vae_dec_input_info = [
+            (
+                "z",
+                [
+                    rbln_config.batch_size,
+                    rbln_config.latent_channels,
+                    rbln_config.latent_sample_size[0],
+                    rbln_config.latent_sample_size[1],
+                ],
+                "float32",
             )
-            return rbln_config
+        ]
+        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
 
-        vae_config = RBLNCompileConfig(
-            input_info=[
-                (
-                    "z",
-                    [rbln_batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
-                    "float32",
-                )
-            ]
-        )
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[vae_config],
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs(compile_cfgs)
         return rbln_config
 
     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNAutoencoderKLConfig,
     ) -> List[rebel.Runtime]:
         if len(compiled_models) == 1:
-            if DEFAULT_COMPILED_MODEL_NAME not in rbln_device_map:
-                cls._raise_missing_compiled_file_error([DEFAULT_COMPILED_MODEL_NAME])
-
-            device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-            return [
-                compiled_models[0].create_runtime(
-                    tensor_type="pt", device=device_val, activate_profiler=activate_profiler
-                )
-            ]
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            # encoder, decoder
+            expected_models = ["encoder", "decoder"]
 
-        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
-            cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
 
-        device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
         return [
-            compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]
 
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
         posterior = self.encoder.encode(x)
+        if not return_dict:
+            return (posterior,)
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return self.decoder.decode(z)
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
+        dec = self.decoder.decode(z)
+        if not return_dict:
+            return (dec,)
+        return DecoderOutput(sample=dec)
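Net effect of the autoencoder_kl.py changes: the encoder is now optional (gated by rbln_config.uses_encoder), compiled submodels are collected by name, runtimes are constructed via rebel.Runtime(...) using the device map carried on the config object, and encode()/decode() adopt the diffusers return_dict convention instead of returning bare values. A minimal sketch of the new calling contract, assuming a vae handle compiled with its encoder as in the example above:

import torch

x = torch.randn(1, 3, 512, 512)  # batch/shape must match the compiled rbln_config

posterior = vae.encode(x).latent_dist  # AutoencoderKLOutput (default return_dict=True)
z = posterior.sample()
image = vae.decode(z).sample           # DecoderOutput, no longer a bare tuple

# Tuple-style access, as in diffusers:
(posterior,) = vae.encode(x, return_dict=False)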
optimum/rbln/diffusers/models/autoencoders/vae.py
@@ -14,11 +14,9 @@
 
 from typing import TYPE_CHECKING, List
 
-import torch # noqa: I001
+import torch
 from diffusers import AutoencoderKL, VQModel
 from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
-from diffusers.models.autoencoders.vq_model import VQEncoderOutput
-from diffusers.models.modeling_outputs import AutoencoderKLOutput
 
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
@@ -34,12 +32,12 @@ class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         moments = self.forward(x.contiguous())
         posterior = DiagonalGaussianDistribution(moments)
-        return AutoencoderKLOutput(latent_dist=posterior)
+        return posterior
 
 
 class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
-        return (self.forward(z),)
+        return self.forward(z)
 
 
 class _VAEDecoder(torch.nn.Module):
@@ -78,7 +76,7 @@ class _VAEEncoder(torch.nn.Module):
 class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         h = self.forward(x.contiguous())
-        return VQEncoderOutput(latents=h)
+        return h
 
 
 class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
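The vae.py changes move output wrapping out of the thin runtime wrappers: RBLNRuntimeVAEEncoder and RBLNRuntimeVQEncoder now return the raw DiagonalGaussianDistribution or latent tensor, and the model classes apply the diffusers output dataclasses instead. A sketch of the resulting division of labor (the vae handle and tensors are assumed):

posterior = vae.encoder.encode(x)  # now a bare DiagonalGaussianDistribution
sample = vae.decoder.decode(z)     # now a plain tensor, not a 1-tuple

out = vae.encode(x)                # model level still wraps: AutoencoderKLOutput(latent_dist=...)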
optimum/rbln/diffusers/models/autoencoders/vq_model.py
@@ -12,24 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, List, Union
 
 import rebel
 import torch
 from diffusers import VQModel
 from diffusers.models.autoencoders.vae import DecoderOutput
 from diffusers.models.autoencoders.vq_model import VQEncoderOutput
-from transformers import PretrainedConfig
 
+from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_diffusers import RBLNDiffusionMixin
+from ...configurations.models.configuration_vq_model import RBLNVQModelConfig
+from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
 from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
 
 logger = get_logger(__name__)
 
@@ -42,126 +42,125 @@ class RBLNVQModel(RBLNModel):
     def __post_init__(self, **kwargs):
         super().__post_init__(**kwargs)
 
-        self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
-        self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[1], main_input_name="z")
+        if self.rbln_config.uses_encoder:
+            self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
+        else:
+            self.encoder = None
+
+        self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[-1], main_input_name="z")
         self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
-        height = self.rbln_config.model_cfg.get("img_height", 512)
-        width = self.rbln_config.model_cfg.get("img_width", 512)
-        self.image_size = [height, width]
+        self.image_size = self.rbln_config.image_size
 
     @classmethod
-    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
-        encoder_model = _VQEncoder(model)
-        decoder_model = _VQDecoder(model)
-        encoder_model.eval()
-        decoder_model.eval()
+    def get_compiled_model(cls, model, rbln_config: RBLNModelConfig):
+        if rbln_config.uses_encoder:
+            expected_models = ["encoder", "decoder"]
+        else:
+            expected_models = ["decoder"]
 
-        enc_compiled_model = cls.compile(encoder_model, rbln_compile_config=rbln_config.compile_cfgs[0])
-        dec_compiled_model = cls.compile(decoder_model, rbln_compile_config=rbln_config.compile_cfgs[1])
+        compiled_models = {}
+        for i, model_name in enumerate(expected_models):
+            if model_name == "encoder":
+                wrapped_model = _VQEncoder(model)
+            else:
+                wrapped_model = _VQDecoder(model)
 
-        return {"encoder": enc_compiled_model, "decoder": dec_compiled_model}
+            wrapped_model.eval()
 
-    @classmethod
-    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
-        batch_size = rbln_config.get("batch_size")
-        if batch_size is None:
-            batch_size = 1
-        img_height = rbln_config.get("img_height")
-        if img_height is None:
-            img_height = 512
-        img_width = rbln_config.get("img_width")
-        if img_width is None:
-            img_width = 512
-
-        rbln_config.update(
-            {
-                "batch_size": batch_size,
-                "img_height": img_height,
-                "img_width": img_width,
-            }
-        )
+            compiled_models[model_name] = cls.compile(wrapped_model, rbln_compile_config=rbln_config.compile_cfgs[i])
+
+        return compiled_models
 
+    @classmethod
+    def update_rbln_config_using_pipe(
+        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+    ) -> "RBLNDiffusionMixinConfig":
         return rbln_config
 
     @classmethod
-    def _get_rbln_config(
+    def _update_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: "PreTrainedModel",
         model_config: "PretrainedConfig",
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        batch_size = rbln_kwargs.get("batch_size")
-        if batch_size is None:
-            batch_size = 1
-
-        height = rbln_kwargs.get("img_height")
-        if height is None:
-            height = 512
-
-        width = rbln_kwargs.get("img_width")
-        if width is None:
-            width = 512
-
+        rbln_config: RBLNVQModelConfig,
+    ) -> RBLNVQModelConfig:
         if hasattr(model_config, "block_out_channels"):
-            scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+            rbln_config.vqmodel_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
         else:
             # image processor default value 8 (int)
-            scale_factor = 8
-
-        enc_shape = (height, width)
-        dec_shape = (height // scale_factor, width // scale_factor)
+            rbln_config.vqmodel_scale_factor = 8
+
+        compile_cfgs = []
+        if rbln_config.uses_encoder:
+            enc_input_info = [
+                (
+                    "x",
+                    [
+                        rbln_config.batch_size,
+                        model_config.in_channels,
+                        rbln_config.sample_size[0],
+                        rbln_config.sample_size[1],
+                    ],
+                    "float32",
+                )
+            ]
+            enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+            compile_cfgs.append(enc_rbln_compile_config)
 
-        enc_input_info = [
-            (
-                "x",
-                [batch_size, model_config.in_channels, enc_shape[0], enc_shape[1]],
-                "float32",
-            )
-        ]
         dec_input_info = [
             (
                 "h",
-                [batch_size, model_config.latent_channels, dec_shape[0], dec_shape[1]],
+                [
+                    rbln_config.batch_size,
+                    model_config.latent_channels,
+                    rbln_config.latent_sample_size[0],
+                    rbln_config.latent_sample_size[1],
+                ],
                 "float32",
             )
         ]
-
-        enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
         dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+        compile_cfgs.append(dec_rbln_compile_config)
 
-        compile_cfgs = [enc_rbln_compile_config, dec_rbln_compile_config]
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=compile_cfgs,
-            rbln_kwargs=rbln_kwargs,
-        )
+        rbln_config.set_compile_cfgs(compile_cfgs)
         return rbln_config
 
     @classmethod
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNVQModelConfig,
    ) -> List[rebel.Runtime]:
         if len(compiled_models) == 1:
-            device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-            return [
-                compiled_models[0].create_runtime(
-                    tensor_type="pt", device=device_val, activate_profiler=activate_profiler
-                )
-            ]
+            # decoder
+            expected_models = ["decoder"]
+        else:
+            # encoder, decoder
+            expected_models = ["encoder", "decoder"]
 
-        device_vals = [rbln_device_map["encoder"], rbln_device_map["decoder"]]
+        if any(model_name not in rbln_config.device_map for model_name in expected_models):
+            cls._raise_missing_compiled_file_error(expected_models)
+
+        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
         return [
-            compiled_model.create_runtime(tensor_type="pt", device=device_val, activate_profiler=activate_profiler)
+            rebel.Runtime(
+                compiled_model,
+                tensor_type="pt",
+                device=device_val,
+                activate_profiler=rbln_config.activate_profiler,
+            )
             for compiled_model, device_val in zip(compiled_models, device_vals)
         ]
 
-    def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def encode(self, x: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
         posterior = self.encoder.encode(x)
+        if not return_dict:
+            return (posterior,)
         return VQEncoderOutput(latents=posterior)
 
-    def decode(self, h: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def decode(self, h: torch.FloatTensor, return_dict: bool = True, **kwargs) -> torch.FloatTensor:
         dec, commit_loss = self.decoder.decode(h, **kwargs)
+        if not return_dict:
+            return (dec, commit_loss)
         return DecoderOutput(sample=dec, commit_loss=commit_loss)
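RBLNVQModel receives the same treatment as RBLNAutoencoderKL: an optional encoder gated by uses_encoder, per-name compile configs, runtimes built through rebel.Runtime, and return_dict-aware encode()/decode(). A hypothetical round trip under the new convention (the vq handle and tensor shapes are assumptions):

h = vq.encode(x).latents  # VQEncoderOutput (default return_dict=True)
out = vq.decode(h)        # DecoderOutput(sample=..., commit_loss=...)
image, commit_loss = out.sample, out.commit_loss

# Tuple-style access:
image, commit_loss = vq.decode(h, return_dict=False)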