optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +27 -13
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +85 -65
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
- optimum/rbln/diffusers/pipelines/__init__.py +60 -12
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +176 -763
- optimum/rbln/modeling_diffusers.py +329 -0
- optimum/rbln/transformers/__init__.py +2 -2
- optimum/rbln/transformers/cache_utils.py +5 -9
- optimum/rbln/transformers/modeling_rope_utils.py +283 -0
- optimum/rbln/transformers/models/__init__.py +80 -31
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
- optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
- optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
- optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/__init__.py +1 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
- optimum/rbln/utils/decorator_utils.py +59 -0
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +21 -0
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +5 -5
- optimum/rbln/utils/submodule.py +114 -0
- optimum/rbln/utils/timer_utils.py +2 -2
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- optimum_rbln-0.1.15.dist-info/RECORD +110 -0
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum_rbln-0.1.12.dist-info/METADATA +0 -119
- optimum_rbln-0.1.12.dist-info/RECORD +0 -103
- optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_diffusers.py (new file)

@@ -0,0 +1,329 @@

````python
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

import copy
import importlib
from os import PathLike
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch

from .modeling import RBLNModel
from .modeling_config import RUNTIME_KEYWORDS, ContextRblnConfig, use_rbln_config
from .utils.decorator_utils import remove_compile_time_kwargs


if TYPE_CHECKING:
    from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel


class RBLNDiffusionMixin:
    """
    RBLNDiffusionMixin provides essential functionalities for compiling Stable Diffusion pipeline components to run on RBLN NPUs.
    This mixin class serves as a base for implementing RBLN-compatible Stable Diffusion pipelines. It contains shared logic for
    handling the core components of Stable Diffusion.

    To use this mixin:

    1. Create a new pipeline class that inherits from both this mixin and the original StableDiffusionPipeline.
    2. Define the required _submodules class variable listing the components to be compiled.
    3. If needed, implement get_default_rbln_config for custom configuration of submodules.

    Example:
        ```python
        class RBLNStableDiffusionPipeline(RBLNDiffusionMixin, StableDiffusionPipeline):
            _submodules = ["text_encoder", "unet", "vae"]

            @classmethod
            def get_default_rbln_config(cls, model, submodule_name, rbln_config):
                # Configuration for other submodules...
                pass
        ```

    Class Variables:
        _submodules: List of submodule names that should be compiled (typically ["text_encoder", "unet", "vae"])

    Methods:
        from_pretrained: Creates and optionally compiles a model from a pretrained checkpoint

    Notes:
        - When `export=True`, all compatible submodules will be compiled for NPU inference
        - The compilation config can be customized per submodule by including submodule names
          as keys in rbln_config
    """

    _submodules = []

    @classmethod
    @property
    def img2img_pipeline(cls):
        return "Img2Img" in cls.__name__

    @classmethod
    @property
    def inpaint_pipeline(cls):
        return "Inpaint" in cls.__name__

    @classmethod
    def get_submodule_rbln_config(
        cls, model: torch.nn.Module, submodule_name: str, rbln_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        submodule = getattr(model, submodule_name)
        submodule_class_name = submodule.__class__.__name__

        if submodule_class_name == "MultiControlNetModel":
            submodule_class_name = "ControlNetModel"

        submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), f"RBLN{submodule_class_name}")

        submodule_config = rbln_config.get(submodule_name, {})
        submodule_config = copy.deepcopy(submodule_config)

        pipe_global_config = {k: v for k, v in rbln_config.items() if k not in cls._submodules}

        submodule_config.update({k: v for k, v in pipe_global_config.items() if k not in submodule_config})
        submodule_config.update(
            {
                "img2img_pipeline": cls.img2img_pipeline,
                "inpaint_pipeline": cls.inpaint_pipeline,
            }
        )
        submodule_config = submodule_cls.update_rbln_config_using_pipe(model, submodule_config)
        return submodule_config

    @staticmethod
    def _maybe_apply_and_fuse_lora(
        model: torch.nn.Module,
        lora_ids: Optional[Union[str, List[str]]] = None,
        lora_weights_names: Optional[Union[str, List[str]]] = None,
        lora_scales: Optional[Union[float, List[float]]] = None,
    ) -> torch.nn.Module:
        lora_ids = [lora_ids] if isinstance(lora_ids, str) else lora_ids
        lora_weights_names = [lora_weights_names] if isinstance(lora_weights_names, str) else lora_weights_names
        lora_scales = [lora_scales] if isinstance(lora_scales, float) else lora_scales

        # adapt lora weight into pipeline before compilation
        if lora_ids and lora_weights_names:
            if len(lora_ids) == 1:
                if len(lora_ids) != len(lora_weights_names):
                    raise ValueError(
                        f"You must define the same number of lora ids ({len(lora_ids)} and lora weights ({len(lora_weights_names)}))"
                    )
                else:
                    model.load_lora_weights(lora_ids[0], weight_name=lora_weights_names[0])
                    model.fuse_lora(lora_scale=lora_scales[0] if lora_scales else 1.0)
            elif len(lora_ids) > 1:
                if not len(lora_ids) == len(lora_weights_names):
                    raise ValueError(
                        f"If you fuse {len(lora_ids)} lora models, but you must define the same number for lora weights and adapters."
                    )

                adapter_names = [f"adapter_{i}" for i in range(len(lora_ids))]

                for lora_id, lora_weight, adapter_name in zip(lora_ids, lora_weights_names, adapter_names):
                    model.load_lora_weights(lora_id, weight_name=lora_weight, adapter_name=adapter_name)

                if lora_scales:
                    model.set_adapters(adapter_names, adapter_weights=lora_scales)

                model.fuse_lora()
        return model

    @classmethod
    @use_rbln_config
    def from_pretrained(
        cls,
        model_id: str,
        *,
        export: bool = False,
        model_save_dir: Optional[PathLike] = None,
        rbln_config: Dict[str, Any] = {},
        lora_ids: Optional[Union[str, List[str]]] = None,
        lora_weights_names: Optional[Union[str, List[str]]] = None,
        lora_scales: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> RBLNModel:
        if export:
            # keep submodules if user passed any of them.
            passed_submodules = {
                name: kwargs.pop(name) for name in cls._submodules if isinstance(kwargs.get(name), RBLNModel)
            }

        else:
            # raise error if any of submodules are torch module.
            model_index_config = None
            for submodule_name in cls._submodules:
                if isinstance(kwargs.get(submodule_name), torch.nn.Module):
                    raise AssertionError(
                        f"{submodule_name} is not compiled torch module. If you want to compile, set `export=True`."
                    )

                # Load submodule outside if runtime kwargs(e.g. device) is specified.
                if submodule_config := rbln_config.get(submodule_name):
                    if any(kwd in submodule_config for kwd in RUNTIME_KEYWORDS):
                        if model_index_config is None:
                            model_index_config = cls.load_config(pretrained_model_name_or_path=model_id)

                        module_name, class_name = model_index_config[submodule_name]
                        if module_name != "optimum.rbln":
                            raise ValueError(
                                f"Invalid module_name '{module_name}' found in model_index.json for "
                                f"submodule '{submodule_name}'. "
                                "Expected 'optimum.rbln'. Please check the model_index.json configuration."
                            )
                        submodule_cls: RBLNModel = getattr(importlib.import_module("optimum.rbln"), class_name)
                        submodule = submodule_cls.from_pretrained(
                            model_id, export=False, subfolder=submodule_name, rbln_config=submodule_config
                        )
                        kwargs[submodule_name] = submodule

        with ContextRblnConfig(
            device=rbln_config.get("device"),
            device_map=rbln_config.get("device_map"),
            create_runtimes=rbln_config.get("create_runtimes"),
            optimize_host_mem=rbln_config.get("optimize_host_memory"),
        ):
            model = super().from_pretrained(pretrained_model_name_or_path=model_id, **kwargs)

        if not export:
            return model

        model = cls._maybe_apply_and_fuse_lora(
            model,
            lora_ids=lora_ids,
            lora_weights_names=lora_weights_names,
            lora_scales=lora_scales,
        )

        compiled_submodules = cls._compile_submodules(model, passed_submodules, model_save_dir, rbln_config)
        return cls._construct_pipe(model, compiled_submodules, model_save_dir, rbln_config)

    @classmethod
    def _compile_submodules(
        cls,
        model: torch.nn.Module,
        passed_submodules: Dict[str, RBLNModel],
        model_save_dir: Optional[PathLike],
        rbln_config: Dict[str, Any],
    ) -> Dict[str, RBLNModel]:
        compiled_submodules = {}

        for submodule_name in cls._submodules:
            submodule = passed_submodules.get(submodule_name) or getattr(model, submodule_name, None)
            submodule_rbln_config = cls.get_submodule_rbln_config(model, submodule_name, rbln_config)

            if submodule is None:
                raise ValueError(f"submodule ({submodule_name}) cannot be accessed since it is not provided.")
            elif isinstance(submodule, RBLNModel):
                pass
            elif submodule_name == "controlnet" and hasattr(submodule, "nets"):
                # In case of multicontrolnet
                submodule = cls._compile_multicontrolnet(
                    controlnets=submodule,
                    model_save_dir=model_save_dir,
                    controlnet_rbln_config=submodule_rbln_config,
                )
            elif isinstance(submodule, torch.nn.Module):
                submodule_cls: RBLNModel = getattr(
                    importlib.import_module("optimum.rbln"), f"RBLN{submodule.__class__.__name__}"
                )
                submodule = submodule_cls.from_model(
                    model=submodule,
                    subfolder=submodule_name,
                    model_save_dir=model_save_dir,
                    rbln_config=submodule_rbln_config,
                )
            else:
                raise ValueError(f"Unknown class of submodule({submodule_name}) : {submodule.__class__.__name__} ")

            compiled_submodules[submodule_name] = submodule
        return compiled_submodules

    @classmethod
    def _compile_multicontrolnet(
        cls,
        controlnets: "MultiControlNetModel",
        model_save_dir: Optional[PathLike],
        controlnet_rbln_config: Dict[str, Any],
    ):
        # Compile multiple ControlNet models for a MultiControlNet setup
        from .diffusers.models.controlnet import RBLNControlNetModel
        from .diffusers.pipelines.controlnet import RBLNMultiControlNetModel

        compiled_controlnets = [
            RBLNControlNetModel.from_model(
                model=controlnet,
                subfolder="controlnet" if i == 0 else f"controlnet_{i}",
                model_save_dir=model_save_dir,
                rbln_config=controlnet_rbln_config,
            )
            for i, controlnet in enumerate(controlnets.nets)
        ]
        return RBLNMultiControlNetModel(compiled_controlnets, config=controlnets.nets[0].config)

    @classmethod
    def _construct_pipe(cls, model, submodules, model_save_dir, rbln_config):
        # Construct finalize pipe setup with compiled submodules and configurations

        if model_save_dir is not None:
            # To skip saving original pytorch modules
            for submodule_name in cls._submodules:
                delattr(model, submodule_name)

            # Direct calling of `save_pretrained` causes config.unet = (None, None).
            # So config must be saved again, later.
            model.save_pretrained(model_save_dir)
            # FIXME: Here, model touches its submodules such as model.unet,
            # Causing warning messeages.

        update_dict = {}
        for submodule_name in cls._submodules:
            # replace submodule
            setattr(model, submodule_name, submodules[submodule_name])
            update_dict[submodule_name] = ("optimum.rbln", submodules[submodule_name].__class__.__name__)

        # Update config to be able to load from model directory.
        #
        # e.g)
        # update_dict = {
        #     "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
        #     "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
        #     "unet": ("optimum.rbln", "RBLNUNet2DConditionModel"),
        # }
        model.register_to_config(**update_dict)

        if model_save_dir:
            # overwrite to replace incorrect config
            model.save_config(model_save_dir)

        if rbln_config.get("optimize_host_memory") is False:
            # Keep compiled_model objs to further analysis. -> TODO: remove soon...
            model.compiled_models = []
            for name in cls._submodules:
                submodule = getattr(model, name)
                model.compiled_models.extend(submodule.compiled_models)

        return model

    @remove_compile_time_kwargs
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)
````
optimum/rbln/transformers/__init__.py

```diff
@@ -28,7 +28,6 @@ from transformers.utils import _LazyModule
 
 _import_structure = {
     "cache_utils": ["RebelDynamicCache"],
-    "generation": ["BatchTextIteratorStreamer"],
     "models": [
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
@@ -57,6 +56,7 @@ _import_structure = {
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNPhiForCausalLM",
+        "RBLNT5EncoderModel",
         "RBLNT5ForConditionalGeneration",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNMidmLMHeadModel",
@@ -67,7 +67,6 @@ _import_structure = {
 
 if TYPE_CHECKING:
     from .cache_utils import RebelDynamicCache
-    from .generation import BatchTextIteratorStreamer
    from .models import (
        RBLNAutoModel,
        RBLNAutoModelForAudioClassification,
@@ -97,6 +96,7 @@ if TYPE_CHECKING:
         RBLNMistralForCausalLM,
         RBLNPhiForCausalLM,
         RBLNQwen2ForCausalLM,
+        RBLNT5EncoderModel,
         RBLNT5ForConditionalGeneration,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
```
optimum/rbln/transformers/cache_utils.py

```diff
@@ -12,9 +12,11 @@ class RebelDynamicCache(DynamicCache):
         `[batch_size, num_heads, seq_len, head_dim]`.
     """
 
-    def __init__(self,
+    def __init__(self, position_ids) -> None:
         super().__init__()
-
+        # batch, _ = position_ids.shape
+        # current_steps = [position_ids[b][0] for b in range(batch)]
+        self.current_steps = position_ids[:, 0]
 
     def assign(
         self,
@@ -58,13 +60,7 @@ class RebelDynamicCache(DynamicCache):
     @classmethod
     def from_input_format(cls, position_ids, num_hidden_layer, *past_key_values) -> "DynamicCache":
         """Converts a cache in the rbln cache format (list of past_kv) into an equivalent `DynamicCache`."""
-
-        batch, _ = position_ids.shape
-        current_steps = [position_ids[b][0] for b in range(batch)]
-
-        assert len(current_steps) == batch
-        cache = cls(current_steps)
-
+        cache = cls(position_ids)
         for layer_idx in range(num_hidden_layer):
             key_states = past_key_values[layer_idx * 2]
             value_states = past_key_values[layer_idx * 2 + 1]
```
optimum/rbln/transformers/modeling_rope_utils.py (new file)

@@ -0,0 +1,283 @@

```python
import math
from typing import Optional, Tuple

import torch
from transformers import PretrainedConfig


def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    seq_len: Optional[int] = None,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies according to the original RoPE implementation
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
    return inv_freq, attention_factor


def _compute_linear_scaling_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    seq_len: Optional[int] = None,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """

    factor = config.rope_scaling["factor"]

    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)

    # Then applies linear scaling to the frequencies.
    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
    # applying scaling to the inverse frequencies is equivalent.
    inv_freq /= factor
    return inv_freq, attention_factor


def _compute_dynamic_ntk_parameters(
    config: Optional[PretrainedConfig] = None,
    seq_len: Optional[int] = None,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`, *optional*):
            The current sequence length, used to update the dynamic RoPE at inference time.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    max_position_embeddings = config.max_position_embeddings
    factor = config.rope_scaling["factor"]

    attention_factor = 1.0  # Unused in this type of RoPE

    # Process with chunk_size to reduce precesion error
    chunk_size = 4096
    chunks = (seq_len + chunk_size - 1) // chunk_size

    inv_freq_list = []
    for i in range(chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, seq_len)

        seq_lens = torch.arange(start, end, dtype=torch.float32).view(-1, 1) + 1.0
        seq_lens = torch.where(seq_lens > max_position_embeddings, seq_lens, max_position_embeddings)

        # Compute the inverse frequencies for each chunk
        scaled_base = base * ((factor * seq_lens / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
        inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

        inv_freq_list.append(inv_freq)

    final_inv_freq = torch.cat(inv_freq_list, dim=0)

    return final_inv_freq, attention_factor


def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] = None) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Please refer to the
    [original paper](https://arxiv.org/abs/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    max_position_embeddings = config.max_position_embeddings
    factor = config.rope_scaling["factor"]

    # Sets the attention factor as suggested in the paper
    attention_factor = config.rope_scaling.get("attention_factor")
    if attention_factor is None:
        attention_factor = 0.1 * math.log(factor) + 1.0

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
        """Find dimension range bounds based on rotations"""
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float()
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )

    return inv_freq, attention_factor


def _compute_longrope_parameters(
    config: PretrainedConfig, seq_len: Optional[int] = None
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
    [original implementation](https://github.com/microsoft/LongRoPE)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    long_factor = config.rope_scaling["long_factor"]
    short_factor = config.rope_scaling["short_factor"]
    factor = config.rope_scaling.get("factor")
    attention_factor = config.rope_scaling.get("attention_factor")

    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    if hasattr(config, "original_max_position_embeddings"):
        max_position_embeddings = config.original_max_position_embeddings
        expanded_max_position_embeddings = config.max_position_embeddings
        factor = expanded_max_position_embeddings / max_position_embeddings
    else:
        max_position_embeddings = config.max_position_embeddings
        expanded_max_position_embeddings = max_position_embeddings * factor

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if factor <= 1.0:
            attention_factor = 1.0
        else:
            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))

    # Compute the inverse frequencies -- scaled based on the target sequence length
    if expanded_max_position_embeddings > max_position_embeddings:
        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
    else:
        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

    return inv_freq, attention_factor


def _compute_llama3_parameters(
    config: PretrainedConfig, seq_len: Optional[int] = None
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies for llama 3.1.

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)

    factor = config.rope_scaling["factor"]  # `8` in the original implementation
    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # wavelen < high_freq_wavelen: do nothing
    # wavelen > low_freq_wavelen: divide by factor
    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # otherwise: interpolate between the two, using a smooth factor
    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor


# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "linear": _compute_linear_scaling_rope_parameters,
    "dynamic": _compute_dynamic_ntk_parameters,
    "yarn": _compute_yarn_parameters,
    "longrope": _compute_longrope_parameters,
    "llama3": _compute_llama3_parameters,
}
```