optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
- optimum/rbln/diffusers/models/controlnet.py +36 -62
- optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
- optimum/rbln/diffusers/pipelines/__init__.py +40 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +117 -144
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/modeling_diffusers.py +400 -0
- optimum/rbln/transformers/__init__.py +10 -0
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -28
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
- optimum/rbln/utils/context.py +58 -0
- optimum/rbln/utils/decorator_utils.py +55 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +4 -4
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
- optimum_rbln-0.1.13.dist-info/RECORD +107 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/RECORD +0 -93
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_config.py
CHANGED
@@ -242,3 +242,54 @@ class RBLNConfig:
|
|
242
242
|
rbln_device_map[cfg.compiled_model_name] = device_val
|
243
243
|
return rbln_device_map
|
244
244
|
return self.runtime_cfg["device_map"]
|
245
|
+
|
246
|
+
|
247
|
+
def use_rbln_config(fn):
    """Decorate ``fn`` so that ``rbln_``-prefixed keyword arguments are folded into ``rbln_config``.

    The wrapper removes every keyword argument whose name starts with ``rbln_`` and strips that
    prefix. It then calls ``fn`` with a single ``rbln_config`` keyword argument:

    * If ``rbln_config`` is an ``RBLNConfig`` instance, only keys listed in ``RUNTIME_KEYWORDS``
      and the internal keys ``compiled_models`` / ``submodules`` may additionally be passed with
      the ``rbln_`` prefix.
      - Runtime keys are applied to the instance in place via ``update_runtime_cfg``.
      - Internal keys are forwarded to ``fn`` with their ``rbln_`` prefix restored.
      - The instance itself is passed as ``rbln_config``.
    * Otherwise ``rbln_config`` (``None`` or a dict) is merged with the stripped ``rbln_``
      arguments into a new dict, which is passed as ``rbln_config``. The caller's dict is not
      modified.

    The wrapper preserves ``fn``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``).

    Raises:
        KeyError: if ``rbln_kwargs`` is passed; if an ``RBLNConfig`` instance is combined with
            ``rbln_`` arguments that are neither runtime nor internal keys; or if a key appears
            both in the ``rbln_config`` dict and as an ``rbln_`` keyword argument.
    """
    from functools import wraps

    # Keys that are forwarded to ``fn`` untouched (with the ``rbln_`` prefix restored) when an
    # ``RBLNConfig`` instance is given.
    rbln_internal_keys = {"compiled_models", "submodules"}

    @wraps(fn)
    def merged_rbln_config_fn(*args, **kwargs):
        rbln_kwargs = kwargs.pop("rbln_kwargs", None)
        if rbln_kwargs is not None:
            raise KeyError("`rbln_kwargs` cannot be specified when using `rbln_config`!")

        rbln_config = kwargs.pop("rbln_config", None)

        # Snapshot the keys first because entries are popped while iterating.
        keys = list(kwargs.keys())
        rbln_kwargs = {key[5:]: kwargs.pop(key) for key in keys if key.startswith("rbln_")}

        if isinstance(rbln_config, RBLNConfig):
            # merge runtime kwargs if exists.
            runtime_rbln_kwargs = {k: rbln_kwargs.pop(k) for k in RUNTIME_KEYWORDS if k in rbln_kwargs}

            # ignore internal keys and recover "rbln_" prefix
            internal_kwargs = {"rbln_" + k: rbln_kwargs.pop(k) for k in rbln_internal_keys if k in rbln_kwargs}

            if len(rbln_kwargs) > 0:
                raise KeyError(
                    f"Failed to merging function argument : {rbln_kwargs.keys()}. "
                    "If you passed `rbln_config` an instance of `RBLNConfig`, "
                    "then none `rbln_` prefixes are allowed to be passed."
                )
            # NOTE: mutates the caller's RBLNConfig instance.
            rbln_config.update_runtime_cfg(runtime_rbln_kwargs)
            return fn(*args, **kwargs, **internal_kwargs, rbln_config=rbln_config)

        elif rbln_config is None:
            rbln_config_dict = {}

        else:
            rbln_config_dict = rbln_config

        for key in rbln_config_dict:
            if key in rbln_kwargs:
                raise KeyError(f"Duplicated key in both `rbln_config` and rbln_{key}.")

        rbln_kwargs.update(rbln_config_dict)
        return fn(*args, **kwargs, rbln_config=rbln_kwargs)

    return merged_rbln_config_fn
|
@@ -0,0 +1,400 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
import importlib
|
24
|
+
from os import PathLike
|
25
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
26
|
+
|
27
|
+
import torch
|
28
|
+
|
29
|
+
from .modeling_base import RBLNModel
|
30
|
+
from .modeling_config import ContextRblnConfig, use_rbln_config
|
31
|
+
from .utils.decorator_utils import remove_compile_time_kwargs
|
32
|
+
|
33
|
+
|
34
|
+
if TYPE_CHECKING:
|
35
|
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
36
|
+
|
37
|
+
|
38
|
+
class _classproperty:
    """Read-only descriptor returning ``fget(owner_class)`` on both class and instance access.

    Replaces stacking ``@classmethod`` on top of ``@property``. That descriptor chaining was
    deprecated in Python 3.11 and removed in Python 3.13, where it yields a bound method
    (always truthy) instead of the computed value.
    """

    def __init__(self, fget):
        self.fget = fget
        self.__doc__ = getattr(fget, "__doc__", None)

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class RBLNDiffusionMixin:
    """
    RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.

    This mixin class serves as a base for implementing RBLN-compatible Stable Diffusion pipelines. It contains shared logic for
    handling the core components of Stable Diffusion.

    To use this mixin:

    1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
    2. Define the required _submodules class variable listing the components to be compiled.
    3. If needed, implement get_default_rbln_config for custom configuration of submodules.

    Example:
        ```python
        class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
            _submodules = ["text_encoder", "unet", "vae"]

            @classmethod
            def get_default_rbln_config(cls, model, submodule_name, rbln_config):
                # Configuration for other submodules...
                pass
        ```

    Class Variables:
        _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])

    Class Properties:
        use_encode: True when the concrete class name contains "Img2Img"; selects img2img sizing defaults.

    Methods:
        from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint

    Notes:
        - When `export=True`, all compatible submodules will be compiled for NPU inference
        - The compilation config can be customized per submodule by including submodule names
          as keys in rbln_config
        - Submodules passed as already-compiled `RBLNModel` instances are reused without recompiling
    """

    _submodules = []

    @_classproperty
    def use_encode(cls) -> bool:
        # Img2Img pipelines are detected purely by the concrete class name.
        return "Img2Img" in cls.__name__

    @staticmethod
    def _get_rbln_image_size(rbln_config: Dict[str, Any]) -> Tuple[Optional[int], Optional[int]]:
        # Read the requested output image size; height and width must be given together or not at all.
        height = rbln_config.get("img_height")
        width = rbln_config.get("img_width")
        if (height is None) != (width is None):
            raise ValueError("Both image height and image width must be given or not given")
        return height, width

    @classmethod
    def _get_unet_batch_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> int:
        # Classifier-free guidance (guidance_scale > 1 without a time-conditioning projection)
        # runs conditional and unconditional inputs together, doubling the unet batch.
        batch_size = rbln_config.get("batch_size", 1)
        do_guidance = rbln_config.get("guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
        return batch_size * 2 if do_guidance else batch_size

    @classmethod
    def _get_vae_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
        """Return the VAE sample size: explicit (img_height, img_width) if given, else derived from the model configs."""
        height, width = cls._get_rbln_image_size(rbln_config)
        if height is not None:
            return (height, width)

        if cls.use_encode:
            return model.vae.config.sample_size

        # In case of text2img, sample size of vae decoder is determined by unet.
        unet_sample_size = model.unet.config.sample_size
        if isinstance(unet_sample_size, int):
            return unet_sample_size * model.vae_scale_factor
        return (
            unet_sample_size[0] * model.vae_scale_factor,
            unet_sample_size[1] * model.vae_scale_factor,
        )

    @classmethod
    def _get_unet_sample_size(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Union[int, Tuple[int, int]]:
        """Return the unet (latent) sample size: image size divided by ``model.vae_scale_factor``, or the model default."""
        height, width = cls._get_rbln_image_size(rbln_config)
        if height is not None:
            return (height // model.vae_scale_factor, width // model.vae_scale_factor)

        if cls.use_encode:
            # In case of img2img, sample size of unet is determined by vae encoder.
            vae_sample_size = model.vae.config.sample_size
            if isinstance(vae_sample_size, int):
                return vae_sample_size // model.vae_scale_factor
            return (
                vae_sample_size[0] // model.vae_scale_factor,
                vae_sample_size[1] // model.vae_scale_factor,
            )

        return model.unet.config.sample_size

    @classmethod
    def _get_default_config(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
        # default configurations shared by several submodules
        return {"img2img_pipeline": cls.use_encode}

    @classmethod
    def get_default_rbln_config_text_encoder(
        cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        batch_size = rbln_config.get("batch_size", 1)
        return {"batch_size": batch_size}

    @classmethod
    def get_default_rbln_config_text_encoder_2(
        cls, model: torch.nn.Module, rbln_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        batch_size = rbln_config.get("batch_size", 1)
        return {"batch_size": batch_size}

    @classmethod
    def get_default_rbln_config_unet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
        # configuration for unet
        unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
        text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
        return {
            **cls._get_default_config(model, rbln_config),
            "max_seq_len": model.text_encoder.config.max_position_embeddings,
            "text_model_hidden_size": text_model_hidden_size,
            "batch_size": unet_batch_size,
            "sample_size": cls._get_unet_sample_size(model, rbln_config),
            "is_controlnet": "controlnet" in model.config.keys(),
        }

    @classmethod
    def get_default_rbln_config_vae(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
        # configuration for vae
        batch_size = rbln_config.get("batch_size", 1)
        return {
            **cls._get_default_config(model, rbln_config),
            "sample_size": cls._get_vae_sample_size(model, rbln_config),
            "batch_size": batch_size,
        }

    @classmethod
    def get_default_rbln_config_controlnet(cls, model: torch.nn.Module, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
        # configuration for controlnet (runs alongside the unet, so it shares the unet batch size)
        unet_batch_size = cls._get_unet_batch_size(model, rbln_config)
        text_model_hidden_size = model.text_encoder_2.config.hidden_size if hasattr(model, "text_encoder_2") else None
        return {
            **cls._get_default_config(model, rbln_config),
            "max_seq_len": model.text_encoder.config.max_position_embeddings,
            "vae_sample_size": cls._get_vae_sample_size(model, rbln_config),
            "unet_sample_size": cls._get_unet_sample_size(model, rbln_config),
            "batch_size": unet_batch_size,
            "text_model_hidden_size": text_model_hidden_size,
        }

    @classmethod
    def get_default_rbln_config(
        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Dispatch to ``get_default_rbln_config_<submodule_name>``; raise ValueError if none exists."""
        config_method = f"get_default_rbln_config_{submodule_name}"
        if hasattr(cls, config_method):
            return getattr(cls, config_method)(model, rbln_config)
        raise ValueError(f"Unknown submodule: {submodule_name}")

    @staticmethod
    def _maybe_apply_and_fuse_lora(
        model: torch.nn.Module,
        lora_ids: Optional[Union[str, List[str]]] = None,
        lora_weights_names: Optional[Union[str, List[str]]] = None,
        lora_scales: Optional[Union[float, List[float]]] = None,
    ) -> torch.nn.Module:
        """Load and fuse LoRA weights into ``model`` before compilation.

        Scalars are promoted to one-element lists. Nothing is applied unless both ``lora_ids``
        and ``lora_weights_names`` are given.

        * A single LoRA is fused with ``lora_scales[0]``, or 1.0 when no scale is given.
        * Several LoRAs are loaded as adapters ``adapter_0``, ``adapter_1``, ...; they are weighted
          by ``lora_scales`` if given, then fused.

        Raises:
            ValueError: if the number of ids and weight names differ.
        """
        if isinstance(lora_ids, str):
            lora_ids = [lora_ids]
        if isinstance(lora_weights_names, str):
            lora_weights_names = [lora_weights_names]
        # Accept int scales too (e.g. ``1``); otherwise ``lora_scales[0]`` below raises TypeError.
        if isinstance(lora_scales, (int, float)):
            lora_scales = [lora_scales]

        # NOTE(review): ids without weight names are silently ignored here — confirm this is intended.
        if not (lora_ids and lora_weights_names):
            return model

        if len(lora_ids) != len(lora_weights_names):
            raise ValueError(
                f"You must define the same number of lora ids ({len(lora_ids)}) "
                f"and lora weights ({len(lora_weights_names)})."
            )

        if len(lora_ids) == 1:
            model.load_lora_weights(lora_ids[0], weight_name=lora_weights_names[0])
            model.fuse_lora(lora_scale=lora_scales[0] if lora_scales else 1.0)
            return model

        adapter_names = [f"adapter_{i}" for i in range(len(lora_ids))]
        for lora_id, lora_weight, adapter_name in zip(lora_ids, lora_weights_names, adapter_names):
            model.load_lora_weights(lora_id, weight_name=lora_weight, adapter_name=adapter_name)

        if lora_scales:
            model.set_adapters(adapter_names, adapter_weights=lora_scales)

        model.fuse_lora()
        return model

    @classmethod
    @use_rbln_config
    def from_pretrained(
        cls,
        model_id: str,
        *,
        export: bool = False,
        model_save_dir: Optional[PathLike] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        lora_ids: Optional[Union[str, List[str]]] = None,
        lora_weights_names: Optional[Union[str, List[str]]] = None,
        lora_scales: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> RBLNModel:
        """Load the pipeline and, when ``export=True``, compile its submodules for RBLN NPUs.

        Args:
            model_id: Forwarded to the parent ``from_pretrained`` as ``pretrained_model_name_or_path``.
            export: If True, apply optional LoRA weights, compile every submodule in
                ``_submodules`` and install the compiled modules. If False, the loaded pipeline is
                returned as-is.
            model_save_dir: Optional directory the compiled pipeline is saved to.
            rbln_config: RBLN configuration assembled by ``use_rbln_config`` from ``rbln_`` kwargs.
                Per-submodule overrides go under the submodule name as key.
            lora_ids, lora_weights_names, lora_scales: See ``_maybe_apply_and_fuse_lora``.
            **kwargs: Forwarded to the parent ``from_pretrained``.

        Raises:
            AssertionError: if ``export=False`` and a submodule in ``kwargs`` is a plain ``torch.nn.Module``.
        """
        if rbln_config is None:
            rbln_config = {}

        passed_submodules = {}
        if export:
            # keep submodules if user passed any of them.
            passed_submodules = {
                name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
            }
        else:
            # raise error if any of submodules are torch module.
            for name in cls._submodules:
                if isinstance(kwargs.get(name), torch.nn.Module):
                    raise AssertionError(
                        f"{name} is not compiled torch module. If you want to compile, set `export=True`."
                    )

        with ContextRblnConfig(
            device=rbln_config.get("device"),
            device_map=rbln_config.get("device_map"),
            create_runtimes=rbln_config.get("create_runtimes"),
            optimize_host_mem=rbln_config.get("optimize_host_memory"),
        ):
            model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)

        if not export:
            return model

        model = cls._maybe_apply_and_fuse_lora(
            model,
            lora_ids=lora_ids,
            lora_weights_names=lora_weights_names,
            lora_scales=lora_scales,
        )

        compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
        return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)

    @classmethod
    def _compile_submodules(
        cls,
        model: torch.nn.Module,
        passed_submodules: Dict[str, RBLNModel],
        model_save_dir: Optional[PathLike],
        rbln_config: Dict[str, Any],
    ) -> Dict[str, RBLNModel]:
        """Return a mapping ``submodule_name -> RBLNModel``, compiling torch submodules as needed.

        Already-compiled submodules (passed by the user or already present on the pipeline) are
        reused as-is. For every other submodule, the compile config is built by layering three
        sources, later ones winning:
          1. the class defaults for that submodule;
          2. the top-level (non-submodule) keys of ``rbln_config``;
          3. ``rbln_config[submodule_name]``.
        """
        compiled_submodules = {}

        # FIXME : Currently, optimum-rbln for transformer does not use base rbln config.
        base_rbln_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}
        for submodule_name in cls._submodules:
            # Explicit membership test instead of `or`, so object truthiness never matters.
            if submodule_name in passed_submodules:
                submodule = passed_submodules[submodule_name]
            else:
                submodule = getattr(model, submodule_name, None)

            if submodule is None:
                raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")

            if isinstance(submodule, RBLNModel):
                compiled_submodules[submodule_name] = submodule
                continue

            # Only build a compile config for submodules that are actually compiled here.
            submodule_rbln_config = cls.get_default_rbln_config(model, submodule_name, rbln_config)
            submodule_rbln_config.update(base_rbln_config)
            submodule_rbln_config.update(rbln_config.get(submodule_name, {}))

            if submodule_name == "controlnet" and hasattr(submodule, "nets"):
                # In case of multicontrolnet
                submodule = cls._compile_multicontrolnet(
                    controlnets=submodule,
                    model_save_dir=model_save_dir,
                    controlnet_rbln_config=submodule_rbln_config,
                )
            elif isinstance(submodule, torch.nn.Module):
                # Resolve the RBLN counterpart by naming convention, e.g. UNet2DConditionModel -> RBLNUNet2DConditionModel.
                submodule_cls: RBLNModel = getattr(
                    importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
                )
                submodule = submodule_cls.from_model(
                    model=submodule,
                    subfolder=submodule_name,
                    model_save_dir=model_save_dir,
                    rbln_config=submodule_rbln_config,
                )
            else:
                raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")

            compiled_submodules[submodule_name] = submodule
        return compiled_submodules

    @classmethod
    def _compile_multicontrolnet(
        cls,
        controlnets: "MultiControlNetModel",
        model_save_dir: Optional[PathLike],
        controlnet_rbln_config: Dict[str, Any],
    ):
        """Compile each net of a MultiControlNet.

        Each net gets its own subfolder: ``controlnet``, ``controlnet_1``, ... The compiled nets are
        wrapped in an RBLNMultiControlNetModel that carries the first net's config.
        """
        from .diffusers.models.controlnet import RBLNControlNetModel
        from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel

        compiled_controlnets = [
            RBLNControlNetModel.from_model(
                model=controlnet,
                subfolder="controlnet" if i == 0 else f"controlnet_{i}",
                model_save_dir=model_save_dir,
                rbln_config=controlnet_rbln_config,
            )
            for i, controlnet in enumerate(controlnets.nets)
        ]
        return RBLNMultiControlNetModel(compiled_controlnets, config=controlnets.nets[0].config)

    @classmethod
    def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
        """Install compiled ``submodules`` on ``model``, register them in its config, and optionally save.

        Saving happens only when ``model_save_dir`` is not None.
        """
        if model_save_dir is not None:
            # To skip saving original pytorch modules
            for submodule_name in cls._submodules:
                delattr(model, submodule_name)

            # Direct calling of `save_pretrained` causes config.unet = (None, None).
            # So config must be saved again, later.
            model.save_pretrained(model_save_dir)
            # FIXME: Here, model touches its submodules such as model.unet,
            # causing warning messages.

        update_dict = {}
        for submodule_name in cls._submodules:
            # replace submodule
            setattr(model, submodule_name, submodules[submodule_name])
            update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)

        # Update config to be able to load from model directory.
        #
        # e.g)
        # update_dict = {
        #     "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
        #     "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
        #     "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
        # }
        model.register_to_config(**update_dict)

        if model_save_dir is not None:
            # overwrite to replace incorrect config
            model.save_config(model_save_dir)

        if rbln_config.get("optimize_host_memory") is False:
            # Keep compiled_model objs to further analysis. -> TODO: remove soon...
            model.compiled_models = []
            for name in cls._submodules:
                submodule = getattr(model, name)
                model.compiled_models.extend(submodule.compiled_models)

        return model

    @remove_compile_time_kwargs
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)
|
@@ -42,18 +42,23 @@ _import_structure = {
|
|
42
42
|
"RBLNAutoModelForSequenceClassification",
|
43
43
|
"RBLNAutoModelForSpeechSeq2Seq",
|
44
44
|
"RBLNAutoModelForVision2Seq",
|
45
|
+
"RBLNBartForConditionalGeneration",
|
45
46
|
"RBLNBartModel",
|
46
47
|
"RBLNBertModel",
|
47
48
|
"RBLNCLIPTextModel",
|
48
49
|
"RBLNCLIPTextModelWithProjection",
|
49
50
|
"RBLNCLIPVisionModel",
|
50
51
|
"RBLNDPTForDepthEstimation",
|
52
|
+
"RBLNExaoneForCausalLM",
|
51
53
|
"RBLNGemmaForCausalLM",
|
52
54
|
"RBLNGPT2LMHeadModel",
|
55
|
+
"RBLNQwen2ForCausalLM",
|
53
56
|
"RBLNWav2Vec2ForCTC",
|
54
57
|
"RBLNWhisperForConditionalGeneration",
|
55
58
|
"RBLNLlamaForCausalLM",
|
56
59
|
"RBLNPhiForCausalLM",
|
60
|
+
"RBLNT5EncoderModel",
|
61
|
+
"RBLNT5ForConditionalGeneration",
|
57
62
|
"RBLNLlavaNextForConditionalGeneration",
|
58
63
|
"RBLNMidmLMHeadModel",
|
59
64
|
"RBLNXLMRobertaModel",
|
@@ -77,12 +82,14 @@ if TYPE_CHECKING:
|
|
77
82
|
RBLNAutoModelForSequenceClassification,
|
78
83
|
RBLNAutoModelForSpeechSeq2Seq,
|
79
84
|
RBLNAutoModelForVision2Seq,
|
85
|
+
RBLNBartForConditionalGeneration,
|
80
86
|
RBLNBartModel,
|
81
87
|
RBLNBertModel,
|
82
88
|
RBLNCLIPTextModel,
|
83
89
|
RBLNCLIPTextModelWithProjection,
|
84
90
|
RBLNCLIPVisionModel,
|
85
91
|
RBLNDPTForDepthEstimation,
|
92
|
+
RBLNExaoneForCausalLM,
|
86
93
|
RBLNGemmaForCausalLM,
|
87
94
|
RBLNGPT2LMHeadModel,
|
88
95
|
RBLNLlamaForCausalLM,
|
@@ -90,6 +97,9 @@ if TYPE_CHECKING:
|
|
90
97
|
RBLNMidmLMHeadModel,
|
91
98
|
RBLNMistralForCausalLM,
|
92
99
|
RBLNPhiForCausalLM,
|
100
|
+
RBLNQwen2ForCausalLM,
|
101
|
+
RBLNT5EncoderModel,
|
102
|
+
RBLNT5ForConditionalGeneration,
|
93
103
|
RBLNWav2Vec2ForCTC,
|
94
104
|
RBLNWhisperForConditionalGeneration,
|
95
105
|
RBLNXLMRobertaModel,
|
@@ -12,9 +12,11 @@ class RebelDynamicCache(DynamicCache):
|
|
12
12
|
`[batch_size, num_heads, seq_len, head_dim]`.
|
13
13
|
"""
|
14
14
|
|
15
|
-
def __init__(self,
|
15
|
+
def __init__(self, position_ids) -> None:
|
16
16
|
super().__init__()
|
17
|
-
|
17
|
+
# batch, _ = position_ids.shape
|
18
|
+
# current_steps = [position_ids[b][0] for b in range(batch)]
|
19
|
+
self.current_steps = position_ids[:, 0]
|
18
20
|
|
19
21
|
def assign(
|
20
22
|
self,
|
@@ -58,13 +60,7 @@ class RebelDynamicCache(DynamicCache):
|
|
58
60
|
@classmethod
|
59
61
|
def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
|
60
62
|
"""Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
|
61
|
-
|
62
|
-
batch, _ = position_ids.shape
|
63
|
-
current_steps = [position_ids[b][0] for b in range(batch)]
|
64
|
-
|
65
|
-
assert len(current_steps) == batch
|
66
|
-
cache = cls(current_steps)
|
67
|
-
|
63
|
+
cache = cls(position_ids)
|
68
64
|
for layer_idx in range(num_hidden_layer):
|
69
65
|
key_states = past_key_values[layer_idx * 2]
|
70
66
|
value_states = past_key_values[layer_idx * 2 + 1]
|