optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,266 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ import logging
25
+ import os
26
+ from pathlib import Path
27
+ from tempfile import TemporaryDirectory
28
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
29
+
30
+ import rebel
31
+ import torch
32
+ from diffusers import ControlNetModel
33
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
34
+ from optimum.exporters import TasksManager
35
+ from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
36
+
37
+ from ....modeling_base import RBLNBaseModel
38
+ from ....modeling_config import RBLNConfig
39
+ from ...models.controlnet import RBLNControlNetModel
40
+
41
+
42
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Typing-only imports: evaluated by static type checkers, skipped at runtime.
if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel
49
+
50
+
51
class RBLNMultiControlNetModel(RBLNBaseModel):
    """RBLN-compiled model that holds several ControlNets used side by side.

    Each ControlNet is compiled separately (see ``_export``) and stored in its own
    directory: the first one in ``<base>``, the following ones in ``<base>_1``,
    ``<base>_2``, ... (see ``_save_pretrained`` / ``_from_pretrained``).
    ``forward`` runs every net on its own conditioning input and sums the
    residuals they return.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModel

    def __init__(
        self,
        models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
        config: PretrainedConfig = None,
        preprocessors: Optional[List] = None,
        rbln_config: Optional[RBLNConfig] = None,
        **kwargs,
    ):
        """Wrap the compiled ControlNets and attach ``config`` to every runtime.

        Args:
            models: One compiled model per ControlNet, in net order.
            config: Model config. Either a ``PretrainedConfig`` or a mapping of
                config fields (presumably what ``ControlNetModel.load_config``
                returns when loading through ``from_pretrained``). ``None`` is
                treated as an empty config.
            preprocessors: Forwarded unchanged to ``RBLNBaseModel``.
            rbln_config: RBLN compilation config, forwarded to ``RBLNBaseModel``.
            **kwargs: Forwarded unchanged to ``RBLNBaseModel``.
        """
        super().__init__(
            models,
            config,
            preprocessors,
            rbln_config,
            **kwargs,
        )

        # A mapping (or nothing) is turned into a PretrainedConfig so every
        # runtime exposes the same config object type.
        if config is None:
            config = PretrainedConfig()
        elif not isinstance(config, PretrainedConfig):
            config = PretrainedConfig(**config)

        # NOTE(review): ``self.runtimes`` is presumably built by
        # ``RBLNBaseModel.__init__`` via ``_create_runtimes`` (one per compiled net).
        for runtime in self.runtimes:
            runtime.config = config
        self.nets = self.runtimes
        self.dtype = torch.float32

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load (or export) a multi-ControlNet model.

        While ``RBLNBaseModel.from_pretrained`` runs, three global loaders are
        redirected to diffusers:

        - ``TasksManager.get_model_from_task`` -> ``MultiControlNetModel.from_pretrained``
        - ``AutoConfig.from_pretrained`` -> ``ControlNetModel.load_config``
        - ``AutoModel.from_pretrained`` -> ``MultiControlNetModel.from_pretrained``

        The originals are always restored afterwards, also when loading raises.
        NOTE(review): the swap is process-global and not thread-safe.
        """

        def get_model_from_task(
            task: str,
            model_name_or_path: Union[str, Path],
            **kwargs,
        ):
            # ``task`` is accepted for signature compatibility and ignored.
            return MultiControlNetModel.from_pretrained(pretrained_model_path=model_name_or_path, **kwargs)

        original_get_model_from_task = TasksManager.get_model_from_task
        original_config_loader = AutoConfig.from_pretrained
        original_model_loader = AutoModel.from_pretrained

        TasksManager.get_model_from_task = get_model_from_task
        AutoConfig.from_pretrained = ControlNetModel.load_config
        AutoModel.from_pretrained = MultiControlNetModel.from_pretrained
        try:
            return super().from_pretrained(*args, **kwargs)
        finally:
            # Without this, a failed load would leave every later
            # AutoConfig/AutoModel/TasksManager call in the process routed to
            # the diffusers loaders.
            AutoConfig.from_pretrained = original_config_loader
            AutoModel.from_pretrained = original_model_loader
            TasksManager.get_model_from_task = original_get_model_from_task

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        **kwargs,
    ) -> RBLNBaseModel:
        """Load every compiled ControlNet stored next to ``model_id``.

        How the first net's directory is found depends on the argument type:

        - a ``str`` ``model_id`` is the first net's directory itself;
        - any other value (a ``Path``, as passed by ``_export``) is a parent
          directory whose ``controlnet`` subdirectory holds the first net.

        Further nets are read from sibling directories suffixed ``_1``, ``_2``,
        ... Discovery stops at the first index whose directory does not exist.

        Raises:
            FileNotFoundError: If the first net's directory does not exist, or a
                discovered net directory contains no ``*.rbln`` file.
        """
        if isinstance(model_id, str):
            model_path = Path(model_id)
        else:
            model_path = model_id / "controlnet"

        rbln_files = []
        rbln_config_dirs = []
        idx = 0
        model_load_path = model_path
        while model_load_path.is_dir():
            rbln_file = next(model_load_path.glob("**/*.rbln"), None)
            if rbln_file is None:
                raise FileNotFoundError(f"Could not find any rbln model file in {model_load_path}")
            rbln_files.append(rbln_file)
            rbln_config_dirs.append(model_load_path)
            idx += 1
            model_load_path = Path(f"{model_path}_{idx}")

        if not rbln_files:
            raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")

        models = []
        for config_dir, rbln_file in zip(rbln_config_dirs, rbln_files):
            # Every net's config is loaded, but only the last one loaded is
            # passed to the constructor.
            rbln_config = RBLNConfig.load(str(config_dir))
            models.append(rebel.RBLNCompiledModel(rbln_file))

        preprocessors = []

        return cls(
            models,
            config,
            preprocessors,
            rbln_config=rbln_config,
            **kwargs,
        )

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """Save each compiled net into its own directory.

        Net 0 goes to ``save_directory`` and net ``i`` (for ``i >= 1``) to
        ``<save_directory>_i``, the layout ``_from_pretrained`` discovers. Each
        directory receives ``compiled_model.rbln`` plus ``self.rbln_config``.
        Missing directories are created.
        """
        # Normalising through Path lets ``save_directory`` be a ``str`` or a
        # ``Path``. It also drops any trailing separator, so the suffixed name
        # is ``<dir>_i`` (what ``_from_pretrained`` probes) and not ``<dir>/_i``.
        base_dir = Path(save_directory)
        for idx, compiled_model in enumerate(self.compiled_models):
            target_dir = base_dir if idx == 0 else Path(f"{base_dir}_{idx}")
            target_dir.mkdir(parents=True, exist_ok=True)
            compiled_model.save(target_dir / "compiled_model.rbln")
            self.rbln_config.save(str(target_dir))

    @classmethod
    @torch.no_grad()
    def _export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "RBLNMultiControlNetModel":
        """Compile every ControlNet of a multi-ControlNet checkpoint.

        The number of nets comes from the ``nets`` of the model returned by
        ``TasksManager.get_model_from_task``.

        - Net ``i`` is read from ``model_id`` (``i == 0``) or ``<model_id>_i``.
        - Each net is compiled with ``RBLNControlNetModel`` using the shared
          ``rbln_batch_size`` / ``rbln_img_width`` / ``rbln_img_height`` /
          ``rbln_vae_scale_factor`` kwargs.
        - The results are saved under a temporary directory and reloaded via
          ``_from_pretrained``.
        - The ``TemporaryDirectory`` is passed on as ``model_save_dir``.
        """
        task = kwargs.pop("task", None)
        if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
        img_width = rbln_config_kwargs.pop("rbln_img_width", None)
        img_height = rbln_config_kwargs.pop("rbln_img_height", None)
        vae_scale_factor = rbln_config_kwargs.pop("rbln_vae_scale_factor", None)
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)

        model: MultiControlNetModel = TasksManager.get_model_from_task(
            task=task,
            model_name_or_path=model_id,
        )
        # Only the net count is needed; release the PyTorch model before the
        # per-net compilation below.
        num_nets = len(model.nets)
        del model

        for idx in range(num_nets):
            suffix = "" if idx == 0 else f"_{idx}"
            controlnet = RBLNControlNetModel.from_pretrained(
                f"{model_id}{suffix}",  # f-string so a Path ``model_id`` also works
                export=True,
                rbln_batch_size=batch_size,
                rbln_img_width=img_width,
                rbln_img_height=img_height,
                rbln_vae_scale_factor=vae_scale_factor,
            )
            # Same naming scheme _from_pretrained probes: controlnet, controlnet_1, ...
            controlnet.save_pretrained(save_dir_path / f"controlnet{suffix}")

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            **rbln_constructor_kwargs,
            **kwargs,
        )

    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        """Create one runtime per compiled net.

        Each runtime uses ``tensor_type="pt"``. All nets are placed on the
        device mapped to ``"compiled_model"`` in ``rbln_device_map``.
        """
        device_val = rbln_device_map["compiled_model"]

        return [
            compiled_model.create_runtime(tensor_type="pt", device=device_val)
            for compiled_model in self.compiled_models
        ]

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: List[torch.tensor],
        conditioning_scale: List[float],
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ):
        """Run every ControlNet on its own conditioning input and sum the results.

        ``controlnet_cond``, ``conditioning_scale`` and ``self.nets`` are
        consumed in lockstep; ``zip`` stops at the shortest of the three. Only
        ``sample``, ``timestep``, ``encoder_hidden_states`` and the per-net image
        and scale reach the compiled nets. ``class_labels``, ``timestep_cond``,
        ``attention_mask``, ``added_cond_kwargs``, ``cross_attention_kwargs``,
        ``guess_mode`` and ``return_dict`` are accepted for signature
        compatibility and ignored.

        Returns:
            ``(down_block_res_samples, mid_block_res_sample)``, where:

            - ``down_block_res_samples`` is the element-wise sum over nets of
              every output except the last (the raw slice when one net ran,
              otherwise a list);
            - ``mid_block_res_sample`` is the sum of each net's last output.

        Raises:
            ValueError: If no net ran (any of the three sequences is empty).
        """
        # Loop-invariant: make the sample contiguous once, not once per net.
        sample = sample.contiguous()

        down_block_res_samples, mid_block_res_sample = None, None
        for image, scale, controlnet in zip(controlnet_cond, conditioning_scale, self.nets):
            output = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=image,
                conditioning_scale=torch.tensor(scale),
            )

            down_samples, mid_sample = output[:-1], output[-1]

            # merge samples
            if mid_block_res_sample is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                # In-place add: mutates the tensor returned by the first net.
                mid_block_res_sample += mid_sample

        if mid_block_res_sample is None:
            raise ValueError(
                "RBLNMultiControlNetModel.forward needs at least one ControlNet, conditioning image "
                "and conditioning scale; got an empty sequence."
            )

        return down_block_res_samples, mid_block_res_sample
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from .pipeline_controlnet_img2img import RBLNStableDiffusionControlNetImg2ImgPipeline
25
+ from .pipeline_stable_diffusion import RBLNStableDiffusionPipeline
26
+ from .pipeline_stable_diffusion_img2img import RBLNStableDiffusionImg2ImgPipeline