optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
@@ -22,12 +22,12 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import logging
|
25
|
-
from typing import TYPE_CHECKING,
|
25
|
+
from typing import TYPE_CHECKING, Optional, Union
|
26
26
|
|
27
27
|
import torch
|
28
|
-
from transformers import PretrainedConfig
|
28
|
+
from transformers import PretrainedConfig
|
29
29
|
|
30
|
-
from ....
|
30
|
+
from ....modeling import RBLNModel
|
31
31
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
32
32
|
|
33
33
|
|
@@ -38,38 +38,6 @@ if TYPE_CHECKING:
|
|
38
38
|
|
39
39
|
|
40
40
|
class RBLNXLMRobertaModel(RBLNModel):
|
41
|
-
original_model_class = XLMRobertaModel
|
42
|
-
original_config_class = XLMRobertaConfig
|
43
|
-
|
44
|
-
@classmethod
|
45
|
-
def get_pytorch_model(
|
46
|
-
cls,
|
47
|
-
model_id: str,
|
48
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
49
|
-
revision: Optional[str] = None,
|
50
|
-
force_download: bool = False,
|
51
|
-
cache_dir: Optional[str] = None,
|
52
|
-
subfolder: str = "",
|
53
|
-
local_files_only: bool = False,
|
54
|
-
trust_remote_code: bool = False,
|
55
|
-
rbln_kwargs: Optional[Dict[str, Any]] = None,
|
56
|
-
**kwargs,
|
57
|
-
) -> "PreTrainedModel":
|
58
|
-
model: "PreTrainedModel" = super().get_pytorch_model(
|
59
|
-
model_id=model_id,
|
60
|
-
use_auth_token=use_auth_token,
|
61
|
-
revision=revision,
|
62
|
-
force_download=force_download,
|
63
|
-
cache_dir=cache_dir,
|
64
|
-
subfolder=subfolder,
|
65
|
-
local_files_only=local_files_only,
|
66
|
-
trust_remote_code=trust_remote_code,
|
67
|
-
rbln_kwargs=rbln_kwargs,
|
68
|
-
library_name="transformers",
|
69
|
-
)
|
70
|
-
|
71
|
-
return model
|
72
|
-
|
73
41
|
@classmethod
|
74
42
|
def _get_rbln_config(
|
75
43
|
cls,
|
@@ -21,13 +21,95 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
|
25
|
-
|
24
|
+
import functools
|
25
|
+
import glob
|
26
|
+
import os
|
27
|
+
from typing import Any, Callable, Dict, Optional
|
26
28
|
|
27
29
|
import torch
|
30
|
+
from safetensors.torch import load_file
|
28
31
|
from torch.nn import Linear, Parameter
|
29
32
|
from torch.nn import functional as F
|
30
33
|
|
34
|
+
from ...utils.logging import get_logger
|
35
|
+
|
36
|
+
|
37
|
+
# Module-level logger for the quantization helpers in this module
# (e.g. update_layers_to_quantize logs the layers it replaced at debug level).
logger = get_logger()
|
38
|
+
|
39
|
+
# Supported quantization formats, each mapped to the precision strings it accepts.
# Precisions look like "w<weight bits>a<activation bits>", e.g. "w4a16".
SUPPORTED_QUANTIZATIONS: Dict[str, list[str]] = {
    "rbln": ["w4a16"],
}


class QuantizationManager:
    """Validate quantization configs and expose them to the compiler through environment variables."""

    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

    @staticmethod
    def _raise_invalid_config_error(
        key: str, value: str, valid_values: list[str], context: Optional[str] = None
    ) -> None:
        """Raise a ValueError describing an unsupported config value (never returns normally).

        Args:
            key: Human-readable name of the offending config entry.
            value: The rejected value.
            valid_values: Values that would have been accepted.
            context: Optional qualifier appended as " for <context>".

        Raises:
            ValueError: always.
        """
        context_info = f" for {context}" if context else ""
        valid_values_str = ", ".join(valid_values)
        raise ValueError(f"Invalid {key}: {value}{context_info}. " f"Supported values are: {valid_values_str}")

    @staticmethod
    def validate_quantization_config(quantize_config: Optional[dict]) -> Optional[dict]:
        """Check a quantization config against SUPPORTED_QUANTIZATIONS.

        Args:
            quantize_config: Dict with "format" and "precision" keys, or a falsy value.

        Returns:
            None if ``quantize_config`` is falsy, otherwise the same dict object unchanged.

        Raises:
            ValueError: if the format or the precision (for that format) is unsupported.
        """
        if not quantize_config:
            return None

        q_format = quantize_config.get("format")
        q_precision = quantize_config.get("precision")

        if q_format not in SUPPORTED_QUANTIZATIONS:
            QuantizationManager._raise_invalid_config_error(
                "quantization format", q_format, list(SUPPORTED_QUANTIZATIONS.keys())
            )

        if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
            QuantizationManager._raise_invalid_config_error(
                "precision", q_precision, SUPPORTED_QUANTIZATIONS[q_format], q_format
            )

        return quantize_config

    @classmethod
    def _set_env_var(cls, name: str, value: str) -> None:
        """Set environment variable ``name`` to ``value`` in this process."""
        os.environ[name] = value

    @classmethod
    def _unset_env_var(cls, name: str) -> None:
        """Remove environment variable ``name`` if present."""
        os.environ.pop(name, None)

    @classmethod
    def set_quantization_env(cls, quantize_config: Optional[dict]) -> Optional[str]:
        """Validate ``quantize_config`` and export its weight bit-width as RBLN_QUANT_BITS.

        Returns:
            The name of the environment variable that was set, or None if no config was given.

        Raises:
            ValueError: if the config is invalid (see validate_quantization_config).
        """
        quantize_config = cls.validate_quantization_config(quantize_config)
        if quantize_config:
            q_precision: str = quantize_config["precision"]
            # "w4a16" -> "4": the digits between the leading "w" and the "a".
            quant_bits = q_precision.split("w")[1].split("a")[0]
            cls._set_env_var(cls.RBLN_QUANT_BITS_ENV, quant_bits)
            return cls.RBLN_QUANT_BITS_ENV
        return None

    @classmethod
    def reset_quantization_env(cls, env_var_name: Optional[str], previous_value: Optional[str] = None) -> None:
        """Undo set_quantization_env.

        Args:
            env_var_name: Name returned by set_quantization_env. None/empty is a no-op.
            previous_value: Value the variable held before it was overwritten. When given it
                is restored; when None the variable is removed.
        """
        if not env_var_name:
            return
        if previous_value is None:
            cls._unset_env_var(env_var_name)
        else:
            cls._set_env_var(env_var_name, previous_value)

    @classmethod
    def with_quantization_env(cls, func: Callable) -> Callable:
        """Decorate ``func`` so RBLN_QUANT_BITS reflects its ``quantize_config`` keyword argument.

        The variable is set before ``func`` runs and restored to its prior state afterwards
        (even on exceptions), so a value the user exported beforehand is not lost.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            quantize_config = kwargs.get("quantize_config")
            # Snapshot any pre-existing value before it may be overwritten.
            previous_value = os.environ.get(cls.RBLN_QUANT_BITS_ENV)
            quantize_env_var = cls.set_quantization_env(quantize_config)
            try:
                return func(*args, **kwargs)
            finally:
                cls.reset_quantization_env(quantize_env_var, previous_value)

        return wrapper
|
112
|
+
|
31
113
|
|
32
114
|
# Constants
|
33
115
|
QUANTIZED_WEIGHTS = {
|
@@ -41,7 +123,15 @@ QUANTIZED_WEIGHTS = {
|
|
41
123
|
}
|
42
124
|
|
43
125
|
|
44
|
-
def
|
126
|
+
def prepare_model_for_quantization(model: torch.nn.Module, model_id: str, n_layer: Optional[int] = None) -> None:
    """
    Convert ``model`` for quantized execution in place.

    Replaces the eligible linear layers with quantized (qlinear) layers, then copies
    checkpoint tensors from ``model_id`` into the model's parameters and buffers.

    Args:
        model: The module to convert; it is modified in place.
        model_id: Location passed to ``load_weights`` to find ``*.safetensors`` files.
        n_layer: Optional layer limit forwarded to ``load_weights``.
    """
    # Layer replacement runs first so that load_weights, which matches tensors against
    # the model's current named parameters/buffers, sees the new qlinear layers.
    update_layers_to_quantize(model)
    load_weights(model, model_id, n_layer=n_layer)
|
132
|
+
|
133
|
+
|
134
|
+
def update_layers_to_quantize(module: torch.nn.Module) -> None:
|
45
135
|
"""
|
46
136
|
Updates specified linear layers to quantized (qlinear) layers in the given module.
|
47
137
|
"""
|
@@ -54,7 +144,33 @@ def update_layers_to_quantized(module: torch.nn.Module) -> None:
|
|
54
144
|
processed_layers.append(name)
|
55
145
|
|
56
146
|
if processed_layers:
|
57
|
-
|
147
|
+
logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
|
148
|
+
|
149
|
+
|
150
|
+
def load_weights(model, model_id, n_layer=None):
    """
    Load safetensor file data directly into the model, filtering by layer if n_layer is provided.

    Args:
        model: Module whose named parameters / buffers receive the checkpoint tensors.
        model_id: Local directory (str or path-like) whose top-level ``*.safetensors``
            files are read. If no such files exist, nothing is loaded.
        n_layer: When not None, tensors whose dotted key has a numeric third component
            (presumably a ``<prefix>.layers.<idx>.…`` layout — TODO confirm) with
            ``idx >= n_layer`` are skipped.

    Checkpoint tensors whose key matches neither a parameter nor a buffer are ignored.
    """
    model_params = dict(model.named_parameters(recurse=True))
    model_buffers = dict(model.named_buffers(recurse=True))

    # Escape the directory so characters such as "[" or "*" in the path are matched
    # literally instead of as glob wildcards; sort for a deterministic load order.
    pattern = os.path.join(glob.escape(str(model_id)), "*.safetensors")
    safetensor_files = sorted(glob.glob(pattern))

    for safetensor_file in safetensor_files:
        file_data = load_file(safetensor_file)
        for key, value in file_data.items():
            if n_layer is not None:
                parts = key.split(".")
                # Equivalent to "index not in range(n_layer)" for the non-negative
                # indices that isdigit() admits, without building a list per call.
                if len(parts) > 2 and parts[2].isdigit() and int(parts[2]) >= n_layer:
                    continue

            if key in model_params:
                model_params[key].data.copy_(value)
            elif key in model_buffers:
                model_buffers[key].data.copy_(value)

        # Drop this file's tensors before loading the next file to bound peak memory.
        del file_data
|
58
174
|
|
59
175
|
|
60
176
|
def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
|
@@ -1,3 +1,27 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import inspect
|
1
25
|
from functools import wraps
|
2
26
|
|
3
27
|
from .logging import get_logger
|
@@ -12,8 +36,8 @@ def remove_compile_time_kwargs(func):
|
|
12
36
|
|
13
37
|
For RBLN-optimized pipelines, several parameters must be determined during compilation
|
14
38
|
and cannot be modified during inference. This decorator:
|
15
|
-
1. Removes and warns about
|
16
|
-
2. Removes and warns about
|
39
|
+
1. Removes and warns about image dimension parameters (height, width)
|
40
|
+
2. Removes and warns about LoRA scale in cross_attention_kwargs
|
17
41
|
|
18
42
|
Args:
|
19
43
|
func: The pipeline's __call__ method to be wrapped
|
@@ -21,15 +45,31 @@ def remove_compile_time_kwargs(func):
|
|
21
45
|
|
22
46
|
@wraps(func)
|
23
47
|
def wrapper(self, *args, **kwargs):
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
)
|
31
|
-
|
32
|
-
|
48
|
+
check_params = {"height", "width"}
|
49
|
+
params = inspect.signature(self.original_class.__call__).parameters
|
50
|
+
|
51
|
+
# If height and width exist in the base pipeline's __call__ method arguments
|
52
|
+
# Otherwise, if there is no height or width of kwargs, it is filled based on the compiled size.
|
53
|
+
if check_params.issubset(params):
|
54
|
+
compiled_image_size = self.get_compiled_image_size()
|
55
|
+
if compiled_image_size is not None:
|
56
|
+
height_exists = "height" in kwargs and kwargs["height"] is not None
|
57
|
+
width_exists = "width" in kwargs and kwargs["width"] is not None
|
58
|
+
if height_exists or width_exists:
|
59
|
+
if not (
|
60
|
+
kwargs.get("height", None) == compiled_image_size[0]
|
61
|
+
and kwargs.get("width", None) == compiled_image_size[1]
|
62
|
+
):
|
63
|
+
logger.warning(
|
64
|
+
"Image dimension parameters (`height`, `width`) will be ignored during inference. "
|
65
|
+
"Image dimensions (%s, %s) must be specified during model compilation using from_pretrained(), (%s, %s).",
|
66
|
+
str(kwargs.get("height", None)),
|
67
|
+
str(kwargs.get("width", None)),
|
68
|
+
str(compiled_image_size[0]),
|
69
|
+
str(compiled_image_size[1]),
|
70
|
+
)
|
71
|
+
kwargs["height"] = compiled_image_size[0]
|
72
|
+
kwargs["width"] = compiled_image_size[1]
|
33
73
|
|
34
74
|
if "cross_attention_kwargs" in kwargs:
|
35
75
|
cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
|
@@ -0,0 +1,131 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import os
|
25
|
+
from pathlib import Path
|
26
|
+
from typing import List, Optional, Union
|
27
|
+
|
28
|
+
from huggingface_hub import HfApi, HfFolder, hf_hub_download
|
29
|
+
|
30
|
+
|
31
|
+
class PushToHubMixin:
    """Mixin that uploads a locally saved model directory to the Hugging Face Hub."""

    def push_to_hub(
        self,
        save_directory: str,
        repository_id: str,
        private: Optional[bool] = None,
        use_auth_token: Union[bool, str] = True,
    ) -> None:
        """
        Upload every file under ``save_directory`` to the Hub repository ``repository_id``.

        The repository is created first if it does not already exist. Each file keeps its
        path relative to ``save_directory``, so sub-directories are reproduced in the
        repository instead of being flattened (which would make same-named files collide).

        Args:
            save_directory: Local directory to upload (walked recursively).
            repository_id: Target repository id on the Hub.
            private: Forwarded to ``HfApi.create_repo``.
            use_auth_token: A token string, or True to use the locally stored token.

        Raises:
            ValueError: if ``use_auth_token`` is falsy (raised by _get_huggingface_token).
        """
        huggingface_token = _get_huggingface_token(use_auth_token)
        api = HfApi()

        api.create_repo(
            token=huggingface_token,
            repo_id=repository_id,
            exist_ok=True,
            private=private,
        )
        for path, _subdirs, files in os.walk(save_directory):
            for name in files:
                local_file_path = os.path.join(path, name)
                # Hub paths always use "/" separators, regardless of the local OS.
                hub_file_path = Path(os.path.relpath(local_file_path, save_directory)).as_posix()
                # FIXME: when huggingface_hub fixes the return of upload_file
                try:
                    api.upload_file(
                        token=huggingface_token,
                        repo_id=repository_id,
                        path_or_fileobj=os.path.abspath(local_file_path),
                        path_in_repo=hub_file_path,
                    )
                except (KeyError, NameError):
                    # Deliberately ignored: see the FIXME above about upload_file's return value.
                    pass
|
64
|
+
|
65
|
+
|
66
|
+
def pull_compiled_model_from_hub(
    model_id: Union[str, Path],
    subfolder: str,
    use_auth_token: Optional[Union[bool, str]],
    revision: Optional[str],
    cache_dir: Optional[str],
    force_download: bool,
    local_files_only: bool,
) -> Path:
    """Pull model files from the Hugging Face Hub.

    Lists the files of ``model_id`` at ``revision``, and validates that ``*.rbln`` files
    and exactly one ``rbln_config.json`` match the ``subfolder`` patterns (the repo root
    when ``subfolder == ""``). It then downloads the repository files into the local HF cache.

    Returns:
        The local cache directory that contains the validated ``rbln_config.json``.

    Raises:
        ValueError: if ``use_auth_token`` is falsy (raised by _get_huggingface_token).
        FileNotFoundError / FileExistsError: from validate_files.

    Note: the file listing is always requested from the Hub, even when
    ``local_files_only`` is True.
    """
    huggingface_token = _get_huggingface_token(use_auth_token)
    repo_files = list(
        map(
            Path,
            HfApi().list_repo_files(model_id, revision=revision, token=huggingface_token),
        )
    )

    pattern_rbln = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
    rbln_files = [p for p in repo_files if p.match(pattern_rbln)]

    pattern_config = "rbln_config.json" if subfolder == "" else f"{subfolder}/rbln_config.json"
    rbln_config_filenames = [p for p in repo_files if p.match(pattern_config)]

    validate_files(rbln_files, rbln_config_filenames, f"repository {model_id}")

    config_repo_path = rbln_config_filenames[0]
    rbln_config_cache_path = None
    for repo_file in repo_files:
        # ``repo_file`` is already a full repo-relative path (it includes the subfolder),
        # so ``subfolder`` must NOT be passed here: hf_hub_download would prefix it again
        # and request "<subfolder>/<subfolder>/...".
        local_path = hf_hub_download(
            repo_id=model_id,
            filename=repo_file.as_posix(),
            token=huggingface_token,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
        )
        if repo_file == config_repo_path:
            rbln_config_cache_path = local_path

    # Anchor the result on the validated config rather than on whichever file happened
    # to be downloaded last, so the directory is deterministic.
    return Path(rbln_config_cache_path).parent
|
107
|
+
|
108
|
+
|
109
|
+
def validate_files(
    files: List[Path],
    config_files: List[Path],
    location: str,
):
    """Ensure a location holds model files and exactly one ``rbln_config.json``.

    Args:
        files: Candidate ``*.rbln`` model files; must be non-empty.
        config_files: Candidate ``rbln_config.json`` files; must contain exactly one entry.
        location: Description of where the files were searched, used in error messages.

    Raises:
        FileNotFoundError: if ``files`` or ``config_files`` is empty (``files`` is checked first).
        FileExistsError: if more than one config file was found.
    """
    if not files:
        raise FileNotFoundError(f"Could not find any rbln model file in {location}")

    config_count = len(config_files)
    if config_count == 0:
        raise FileNotFoundError(f"Could not find `rbln_config.json` file in {location}")
    if config_count > 1:
        raise FileExistsError(f"Multiple rbln_config.json files found in {location}. This is not expected.")
|
123
|
+
|
124
|
+
|
125
|
+
def _get_huggingface_token(use_auth_token: Union[bool, str]) -> str:
|
126
|
+
if isinstance(use_auth_token, str):
|
127
|
+
return use_auth_token
|
128
|
+
elif use_auth_token:
|
129
|
+
return HfFolder.get_token()
|
130
|
+
else:
|
131
|
+
raise ValueError("`use_auth_token` must be provided to interact with the Hugging Face Hub.")
|
@@ -37,11 +37,32 @@ class VersionCompat:
|
|
37
37
|
|
38
38
|
|
39
39
|
RBLN_VERSION_COMPATS = {
|
40
|
+
"0.2.0": [
|
41
|
+
VersionCompat(
|
42
|
+
package_name="rebel-compiler",
|
43
|
+
min_version="0.7.1",
|
44
|
+
max_version="0.7.2",
|
45
|
+
),
|
46
|
+
],
|
47
|
+
"0.1.15": [
|
48
|
+
VersionCompat(
|
49
|
+
package_name="rebel-compiler",
|
50
|
+
min_version="0.6.2",
|
51
|
+
max_version="0.6.3",
|
52
|
+
),
|
53
|
+
],
|
54
|
+
"0.1.14": [
|
55
|
+
VersionCompat(
|
56
|
+
package_name="rebel-compiler",
|
57
|
+
min_version="0.6.2",
|
58
|
+
max_version="0.6.3",
|
59
|
+
),
|
60
|
+
],
|
40
61
|
"0.1.13": [
|
41
62
|
VersionCompat(
|
42
63
|
package_name="rebel-compiler",
|
43
64
|
min_version="0.6.0",
|
44
|
-
max_version="0.6.
|
65
|
+
max_version="0.6.2",
|
45
66
|
),
|
46
67
|
],
|
47
68
|
"0.1.12": [
|
optimum/rbln/utils/logging.py
CHANGED
@@ -1,3 +1,40 @@
|
|
1
|
+
# Copyright 2020 Optuna, Hugging Face
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Copyright 2024 Rebellions Inc.
|
16
|
+
|
17
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
18
|
+
# you may not use this file except in compliance with the License.
|
19
|
+
# You may obtain a copy of the License at:
|
20
|
+
|
21
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
22
|
+
|
23
|
+
# Unless required by applicable law or agreed to in writing, software
|
24
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
25
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
26
|
+
# See the License for the specific language governing permissions and
|
27
|
+
# limitations under the License.
|
28
|
+
|
29
|
+
# Portions of this software are licensed under the Apache License,
|
30
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
31
|
+
# additional information regarding copyright ownership.
|
32
|
+
|
33
|
+
# All other portions of this software, including proprietary code,
|
34
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
35
|
+
# copied, modified, or distributed without prior written permission
|
36
|
+
# from Rebellions Inc.
|
37
|
+
|
1
38
|
"""
|
2
39
|
Logging utilities.
|
3
40
|
Modified from `transformers.utils.logging.py`
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
# Prefix used for RBLN model class names
RBLN_PREFIX = "RBLN"


def convert_hf_to_rbln_model_name(hf_model_name: str) -> str:
    """
    Build the RBLN class name for a Hugging Face class name by prepending ``RBLN_PREFIX``.

    Args:
        hf_model_name (str): The Hugging Face model name, e.g. ``"LlamaForCausalLM"``.

    Returns:
        str: The RBLN model name, e.g. ``"RBLNLlamaForCausalLM"``.
    """
    return "".join((RBLN_PREFIX, hf_model_name))


def convert_rbln_to_hf_model_name(rbln_model_name: str) -> str:
    """
    Recover the Hugging Face class name from an RBLN class name.

    Strips one leading ``RBLN_PREFIX``; a name without the prefix is returned unchanged.

    Args:
        rbln_model_name (str): The RBLN model name, e.g. ``"RBLNLlamaForCausalLM"``.

    Returns:
        str: The Hugging Face model name, e.g. ``"LlamaForCausalLM"``.
    """
    if not rbln_model_name.startswith(RBLN_PREFIX):
        return rbln_model_name
    return rbln_model_name[len(RBLN_PREFIX):]
|
@@ -35,15 +35,15 @@ class RBLNPytorchRuntime:
|
|
35
35
|
self.runtime = runtime
|
36
36
|
for key, value in kwargs.items():
|
37
37
|
setattr(self, key, value)
|
38
|
-
for mandatory_member in
|
38
|
+
for mandatory_member in self.mandatory_members:
|
39
39
|
if mandatory_member not in kwargs:
|
40
|
-
raise AttributeError(f"`{mandatory_member}` should be assigned to {__class__.__name__} objects.")
|
40
|
+
raise AttributeError(f"`{mandatory_member}` should be assigned to {self.__class__.__name__} objects.")
|
41
41
|
|
42
42
|
def __call__(self, *args: Any, **kwds: Any) -> Any:
|
43
43
|
return self.forward(*args, **kwds)
|
44
44
|
|
45
45
|
def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
|
46
|
-
# filtering
|
46
|
+
# filtering useless args or kwarg such as None.
|
47
47
|
args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
|
48
48
|
kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
|
49
49
|
output = self.runtime(*args, **kwargs)
|
@@ -76,17 +76,21 @@ class UnavailableRuntime:
|
|
76
76
|
class ContextRblnConfig:
|
77
77
|
_local = threading.local()
|
78
78
|
|
79
|
-
def __init__(
|
79
|
+
def __init__(
|
80
|
+
self, device=None, device_map=None, create_runtimes=None, optimize_host_mem=None, activate_profiler=None
|
81
|
+
):
|
80
82
|
self.device = device
|
81
83
|
self.device_map = device_map
|
82
84
|
self.create_runtimes = create_runtimes
|
83
85
|
self.optimize_host_mem = optimize_host_mem
|
86
|
+
self.activate_profiler = activate_profiler
|
84
87
|
|
85
88
|
def __enter__(self):
|
86
89
|
self._local.device = self.device
|
87
90
|
self._local.device_map = self.device_map
|
88
91
|
self._local.create_runtimes = self.create_runtimes
|
89
92
|
self._local.optimize_host_memory = self.optimize_host_mem
|
93
|
+
self._local.activate_profiler = self.activate_profiler
|
90
94
|
return self
|
91
95
|
|
92
96
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
@@ -94,6 +98,7 @@ class ContextRblnConfig:
|
|
94
98
|
self._local.device_map = None
|
95
99
|
self._local.create_runtimes = None
|
96
100
|
self._local.optimize_host_memory = None
|
101
|
+
self._local.activate_profiler = None
|
97
102
|
|
98
103
|
@classmethod
|
99
104
|
def get_current_context(cls):
|
@@ -102,4 +107,5 @@ class ContextRblnConfig:
|
|
102
107
|
"device_map": getattr(cls._local, "device_map", None),
|
103
108
|
"create_runtimes": getattr(cls._local, "create_runtimes", None),
|
104
109
|
"optimize_host_memory": getattr(cls._local, "optimize_host_memory", None),
|
110
|
+
"activate_profiler": getattr(cls._local, "activate_profiler", None),
|
105
111
|
}
|
optimum/rbln/utils/save_utils.py
CHANGED
@@ -1,3 +1,16 @@
|
|
1
|
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
1
14
|
# Copyright 2024 Rebellions Inc.
|
2
15
|
|
3
16
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
@@ -21,6 +34,10 @@
|
|
21
34
|
# copied, modified, or distributed without prior written permission
|
22
35
|
# from Rebellions Inc.
|
23
36
|
|
37
|
+
"""
|
38
|
+
Refer to huggingface/optimum/blob/4fdeea77d71e79451ba53e0c1f9d8f37e9704268/optimum/utils/save_utils.py
|
39
|
+
"""
|
40
|
+
|
24
41
|
import logging
|
25
42
|
from pathlib import Path
|
26
43
|
from typing import List, Union
|