optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. optimum/rbln/__init__.py +41 -38
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +26 -2
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
  5. optimum/rbln/diffusers/models/__init__.py +36 -3
  6. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  7. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
  9. optimum/rbln/diffusers/models/controlnet.py +54 -14
  10. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  12. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  13. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
  14. optimum/rbln/diffusers/pipelines/__init__.py +23 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  32. optimum/rbln/modeling.py +238 -0
  33. optimum/rbln/modeling_base.py +186 -760
  34. optimum/rbln/modeling_config.py +31 -7
  35. optimum/rbln/ops/__init__.py +26 -0
  36. optimum/rbln/ops/attn.py +221 -0
  37. optimum/rbln/ops/flash_attn.py +70 -0
  38. optimum/rbln/ops/kv_cache_update.py +69 -0
  39. optimum/rbln/transformers/__init__.py +20 -2
  40. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  41. optimum/rbln/transformers/modeling_generic.py +385 -0
  42. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  43. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  44. optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
  45. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  46. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  47. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
  48. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  49. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  50. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
  52. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
  53. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  54. optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
  55. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  56. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  57. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  58. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  59. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  60. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
  61. optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
  62. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
  63. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  64. optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
  65. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  66. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  68. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  69. optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  71. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  72. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  73. optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
  74. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  75. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  76. optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
  77. optimum/rbln/utils/decorator_utils.py +51 -11
  78. optimum/rbln/utils/hub.py +131 -0
  79. optimum/rbln/utils/import_utils.py +22 -1
  80. optimum/rbln/utils/logging.py +37 -0
  81. optimum/rbln/utils/model_utils.py +52 -0
  82. optimum/rbln/utils/runtime_utils.py +10 -4
  83. optimum/rbln/utils/save_utils.py +17 -0
  84. optimum/rbln/utils/submodule.py +137 -0
  85. optimum_rbln-0.2.0.dist-info/METADATA +117 -0
  86. optimum_rbln-0.2.0.dist-info/RECORD +114 -0
  87. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
  88. optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
  89. optimum/rbln/transformers/cache_utils.py +0 -107
  90. optimum/rbln/transformers/generation/streamers.py +0 -139
  91. optimum/rbln/transformers/generation/utils.py +0 -397
  92. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  93. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  94. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  95. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  96. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  97. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  98. optimum/rbln/utils/context.py +0 -58
  99. optimum/rbln/utils/timer_utils.py +0 -43
  100. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  101. optimum_rbln-0.1.13.dist-info/RECORD +0 -107
  102. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  103. optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -22,12 +22,12 @@
 # from Rebellions Inc.

 import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union

 import torch
-from transformers import PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers import PretrainedConfig

-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig


@@ -38,38 +38,6 @@ if TYPE_CHECKING:


 class RBLNXLMRobertaModel(RBLNModel):
-    original_model_class = XLMRobertaModel
-    original_config_class = XLMRobertaConfig
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        model: "PreTrainedModel" = super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_kwargs=rbln_kwargs,
-            library_name="transformers",
-        )
-
-        return model
-
     @classmethod
     def _get_rbln_config(
         cls,
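For orientation, a minimal usage sketch of the slimmed-down class above (not part of the diff; the checkpoint id is just an example and rbln_max_seq_len is an assumed compile-time option following the optimum-rbln rbln_* convention):

from optimum.rbln import RBLNXLMRobertaModel

# The generic RBLNModel.get_pytorch_model now handles loading, so the
# per-model override removed above is no longer needed.
model = RBLNXLMRobertaModel.from_pretrained(
    "FacebookAI/xlm-roberta-base",
    export=True,             # compile the PyTorch checkpoint for RBLN NPUs
    rbln_max_seq_len=512,    # assumed rbln_* compile-time option name
)
model.save_pretrained("xlm-roberta-base-rbln")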
optimum/rbln/transformers/utils/rbln_quantization.py
@@ -21,13 +21,95 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-
-from typing import Any
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, Optional

 import torch
+from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F

+from ...utils.logging import get_logger
+
+
+logger = get_logger()
+
+SUPPORTED_QUANTIZATIONS: Dict[str, list[str]] = {
+    "rbln": ["w4a16"],
+}
+
+
+class QuantizationManager:
+    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
+    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+
+    @staticmethod
+    def _raise_invalid_config_error(
+        key: str, value: str, valid_values: list[str], context: Optional[str] = None
+    ) -> None:
+        context_info = f" for {context}" if context else ""
+        valid_values_str = ", ".join(valid_values)
+        raise ValueError(f"Invalid {key}: {value}{context_info}. " f"Supported values are: {valid_values_str}")
+
+    @staticmethod
+    def validate_quantization_config(quantize_config: Optional[dict]) -> Optional[dict]:
+        if not quantize_config:
+            return None
+
+        q_format = quantize_config.get("format")
+        q_precision = quantize_config.get("precision")
+
+        if q_format not in SUPPORTED_QUANTIZATIONS:
+            QuantizationManager._raise_invalid_config_error(
+                "quantization format", q_format, list(SUPPORTED_QUANTIZATIONS.keys())
+            )
+
+        if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+            QuantizationManager._raise_invalid_config_error(
+                "precision", q_precision, SUPPORTED_QUANTIZATIONS[q_format], q_format
+            )
+
+        return quantize_config
+
+    @classmethod
+    def _set_env_var(cls, name: str, value: str) -> None:
+        os.environ[name] = value
+
+    @classmethod
+    def _unset_env_var(cls, name: str) -> None:
+        os.environ.pop(name, None)
+
+    @classmethod
+    def set_quantization_env(cls, quantize_config: Optional[dict]) -> Optional[str]:
+        quantize_config = cls.validate_quantization_config(quantize_config)
+        if quantize_config:
+            q_precision: str = quantize_config["precision"]
+            quant_bits = q_precision.split("w")[1].split("a")[0]
+            cls._set_env_var(cls.RBLN_QUANT_BITS_ENV, quant_bits)
+            return cls.RBLN_QUANT_BITS_ENV
+        return None
+
+    @classmethod
+    def reset_quantization_env(cls, env_var_name: Optional[str]) -> None:
+        if env_var_name:
+            cls._unset_env_var(env_var_name)
+
+    @classmethod
+    def with_quantization_env(cls, func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            quantize_config = kwargs.get("quantize_config")
+            quantize_env_var = cls.set_quantization_env(quantize_config)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                cls.reset_quantization_env(quantize_env_var)
+
+        return wrapper
+

 # Constants
 QUANTIZED_WEIGHTS = {
@@ -41,7 +123,15 @@ QUANTIZED_WEIGHTS = {
 }


-def update_layers_to_quantized(module: torch.nn.Module) -> None:
+def prepare_model_for_quantization(model: torch.nn.Module, model_id: str, n_layer: Optional[int] = None) -> None:
+    """
+    Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
+    """
+    update_layers_to_quantize(model)
+    load_weights(model, model_id, n_layer)
+
+
+def update_layers_to_quantize(module: torch.nn.Module) -> None:
     """
     Updates specified linear layers to quantized (qlinear) layers in the given module.
     """
@@ -54,7 +144,33 @@ def update_layers_to_quantized(module: torch.nn.Module) -> None:
             processed_layers.append(name)

     if processed_layers:
-        print(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def load_weights(model, model_id, n_layer=None):
+    """
+    Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))
+    safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+    target_layers = list(range(n_layer)) if n_layer is not None else None
+
+    for safetensor_file in safetensor_files:
+        file_data = load_file(safetensor_file)
+        for key, value in file_data.items():
+            if target_layers is not None:
+                parts = key.split(".")
+
+                if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                    continue
+
+            if key in model_params:
+                model_params[key].data.copy_(value)
+            elif key in model_buffers:
+                model_buffers[key].data.copy_(value)


 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
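To make the decorator's control flow concrete, a minimal sketch (not part of the diff; compile_model is a hypothetical entry point, and the import path mirrors the file location above):

import os

from optimum.rbln.transformers.utils.rbln_quantization import QuantizationManager

# The decorator looks up `quantize_config` in the wrapped function's kwargs.
@QuantizationManager.with_quantization_env
def compile_model(model, *, quantize_config=None):
    # While this body runs, RBLN_QUANT_BITS is set ("4" for "w4a16").
    return os.environ.get(QuantizationManager.RBLN_QUANT_BITS_ENV)

bits = compile_model(None, quantize_config={"format": "rbln", "precision": "w4a16"})
assert bits == "4"                           # "w4a16" -> the digits between 'w' and 'a'
assert "RBLN_QUANT_BITS" not in os.environ   # reset by the finally block

The try/finally keeps the environment variable scoped to a single compile call, even when compilation raises.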
optimum/rbln/utils/decorator_utils.py
@@ -1,3 +1,27 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
 from functools import wraps

 from .logging import get_logger
@@ -12,8 +36,8 @@ def remove_compile_time_kwargs(func):

     For RBLN-optimized pipelines, several parameters must be determined during compilation
     and cannot be modified during inference. This decorator:
-    1. Removes and warns about LoRA scale in cross_attention_kwargs
-    2. Removes and warns about image dimension parameters (height, width)
+    1. Removes and warns about image dimension parameters (height, width)
+    2. Removes and warns about LoRA scale in cross_attention_kwargs

     Args:
         func: The pipeline's __call__ method to be wrapped
@@ -21,15 +45,31 @@

     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        height_exists = "height" in kwargs and kwargs["height"] is not None
-        width_exists = "width" in kwargs and kwargs["width"] is not None
-        if height_exists or width_exists:
-            logger.warning(
-                "Image dimension parameters (`height`, `width`) will be ignored during inference. "
-                "Image dimensions must be specified during model compilation using from_pretrained()."
-            )
-            kwargs.pop("width", None)
-            kwargs.pop("height", None)
+        check_params = {"height", "width"}
+        params = inspect.signature(self.original_class.__call__).parameters
+
+        # If height and width exist in the base pipeline's __call__ method arguments
+        # Otherwise, if there is no height or width of kwargs, it is filled based on the compiled size.
+        if check_params.issubset(params):
+            compiled_image_size = self.get_compiled_image_size()
+            if compiled_image_size is not None:
+                height_exists = "height" in kwargs and kwargs["height"] is not None
+                width_exists = "width" in kwargs and kwargs["width"] is not None
+                if height_exists or width_exists:
+                    if not (
+                        kwargs.get("height", None) == compiled_image_size[0]
+                        and kwargs.get("width", None) == compiled_image_size[1]
+                    ):
+                        logger.warning(
+                            "Image dimension parameters (`height`, `width`) will be ignored during inference. "
+                            "Image dimensions (%s, %s) must be specified during model compilation using from_pretrained(), (%s, %s).",
+                            str(kwargs.get("height", None)),
+                            str(kwargs.get("width", None)),
+                            str(compiled_image_size[0]),
+                            str(compiled_image_size[1]),
+                        )
+                kwargs["height"] = compiled_image_size[0]
+                kwargs["width"] = compiled_image_size[1]

         if "cross_attention_kwargs" in kwargs:
             cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
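A minimal sketch (not part of the diff) of the new behavior, with stub classes standing in for a compiled RBLN pipeline; original_class and get_compiled_image_size are the hooks the wrapper expects on self:

from optimum.rbln.utils.decorator_utils import remove_compile_time_kwargs

class _StubBase:
    # Base pipeline whose __call__ signature is inspected for height/width.
    def __call__(self, height=None, width=None, **kwargs):
        pass

class _StubPipeline:
    original_class = _StubBase

    def get_compiled_image_size(self):
        return (512, 512)  # size fixed at compile time

    @remove_compile_time_kwargs
    def __call__(self, **kwargs):
        return kwargs["height"], kwargs["width"]

pipe = _StubPipeline()
print(pipe(height=768, width=768))  # warns, then clamps to the compiled (512, 512)
print(pipe())                       # silently filled in from the compiled size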
optimum/rbln/utils/hub.py
@@ -0,0 +1,131 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+from huggingface_hub import HfApi, HfFolder, hf_hub_download
+
+
+class PushToHubMixin:
+    def push_to_hub(
+        self,
+        save_directory: str,
+        repository_id: str,
+        private: Optional[bool] = None,
+        use_auth_token: Union[bool, str] = True,
+    ) -> str:
+        huggingface_token = _get_huggingface_token(use_auth_token)
+        api = HfApi()
+
+        api.create_repo(
+            token=huggingface_token,
+            repo_id=repository_id,
+            exist_ok=True,
+            private=private,
+        )
+        for path, subdirs, files in os.walk(save_directory):
+            for name in files:
+                local_file_path = os.path.join(path, name)
+                _, hub_file_path = os.path.split(local_file_path)
+                # FIXME: when huggingface_hub fixes the return of upload_file
+                try:
+                    api.upload_file(
+                        token=huggingface_token,
+                        repo_id=f"{repository_id}",
+                        path_or_fileobj=os.path.join(os.getcwd(), local_file_path),
+                        path_in_repo=hub_file_path,
+                    )
+                except KeyError:
+                    pass
+                except NameError:
+                    pass
+
+
+def pull_compiled_model_from_hub(
+    model_id: Union[str, Path],
+    subfolder: str,
+    use_auth_token: Optional[Union[bool, str]],
+    revision: Optional[str],
+    cache_dir: Optional[str],
+    force_download: bool,
+    local_files_only: bool,
+) -> Path:
+    """Pull model files from the Hugging Face Hub."""
+    huggingface_token = _get_huggingface_token(use_auth_token)
+    repo_files = list(
+        map(
+            Path,
+            HfApi().list_repo_files(model_id, revision=revision, token=huggingface_token),
+        )
+    )
+
+    pattern_rbln = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
+    rbln_files = [p for p in repo_files if p.match(pattern_rbln)]
+
+    pattern_config = "rbln_config.json" if subfolder == "" else f"{subfolder}/rbln_config.json"
+    rbln_config_filenames = [p for p in repo_files if p.match(pattern_config)]
+
+    validate_files(rbln_files, rbln_config_filenames, f"repository {model_id}")
+
+    filenames = [str(path) for path in repo_files]
+
+    for filename in filenames:
+        rbln_config_cache_path = hf_hub_download(
+            repo_id=model_id,
+            filename=filename,
+            subfolder=subfolder,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
+
+    return Path(rbln_config_cache_path).parent
+
+
+def validate_files(
+    files: List[Path],
+    config_files: List[Path],
+    location: str,
+):
+    """Validate the presence and count of required files."""
+    if len(files) == 0:
+        raise FileNotFoundError(f"Could not find any rbln model file in {location}")
+
+    if len(config_files) == 0:
+        raise FileNotFoundError(f"Could not find `rbln_config.json` file in {location}")
+
+    if len(config_files) > 1:
+        raise FileExistsError(f"Multiple rbln_config.json files found in {location}. This is not expected.")
+
+
+def _get_huggingface_token(use_auth_token: Union[bool, str]) -> str:
+    if isinstance(use_auth_token, str):
+        return use_auth_token
+    elif use_auth_token:
+        return HfFolder.get_token()
+    else:
+        raise ValueError("`use_auth_token` must be provided to interact with the Hugging Face Hub.")
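A usage sketch for the new helper (not part of the diff; "org/compiled-model-rbln" is a placeholder repo id):

from optimum.rbln.utils.hub import pull_compiled_model_from_hub

model_dir = pull_compiled_model_from_hub(
    model_id="org/compiled-model-rbln",
    subfolder="",
    use_auth_token=True,   # or a token string; False raises ValueError
    revision=None,
    cache_dir=None,
    force_download=False,
    local_files_only=False,
)
print(model_dir)  # local directory containing rbln_config.json and the *.rbln files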
optimum/rbln/utils/import_utils.py
@@ -37,11 +37,32 @@ class VersionCompat:


 RBLN_VERSION_COMPATS = {
+    "0.2.0": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.7.1",
+            max_version="0.7.2",
+        ),
+    ],
+    "0.1.15": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.6.2",
+            max_version="0.6.3",
+        ),
+    ],
+    "0.1.14": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.6.2",
+            max_version="0.6.3",
+        ),
+    ],
     "0.1.13": [
         VersionCompat(
             package_name="rebel-compiler",
             min_version="0.6.0",
-            max_version="0.6.1",
+            max_version="0.6.2",
         ),
     ],
     "0.1.12": [
optimum/rbln/utils/logging.py
@@ -1,3 +1,40 @@
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
 """
 Logging utilities.
 Modified from `transformers.utils.logging.py`
optimum/rbln/utils/model_utils.py
@@ -0,0 +1,52 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+# Prefix used for RBLN model class names
+RBLN_PREFIX = "RBLN"
+
+
+def convert_hf_to_rbln_model_name(hf_model_name: str):
+    """
+    Convert Hugging Face model name to RBLN model name.
+
+    Args:
+        hf_model_name (str): The Hugging Face model name.
+
+    Returns:
+        str: The corresponding RBLN model name.
+    """
+    return RBLN_PREFIX + hf_model_name
+
+
+def convert_rbln_to_hf_model_name(rbln_model_name: str):
+    """
+    Convert RBLN model name to Hugging Face model name.
+
+    Args:
+        rbln_model_name (str): The RBLN model name.
+
+    Returns:
+        str: The corresponding Hugging Face model name.
+    """
+
+    return rbln_model_name.removeprefix(RBLN_PREFIX)
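For example (not part of the diff):

from optimum.rbln.utils.model_utils import (
    convert_hf_to_rbln_model_name,
    convert_rbln_to_hf_model_name,
)

assert convert_hf_to_rbln_model_name("LlamaForCausalLM") == "RBLNLlamaForCausalLM"
assert convert_rbln_to_hf_model_name("RBLNLlamaForCausalLM") == "LlamaForCausalLM"
# str.removeprefix requires Python >= 3.9.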
optimum/rbln/utils/runtime_utils.py
@@ -35,15 +35,15 @@ class RBLNPytorchRuntime:
         self.runtime = runtime
         for key, value in kwargs.items():
             setattr(self, key, value)
-        for mandatory_member in __class__.mandatory_members:
+        for mandatory_member in self.mandatory_members:
             if mandatory_member not in kwargs:
-                raise AttributeError(f"`{mandatory_member}` should be assigned to {__class__.__name__} objects.")
+                raise AttributeError(f"`{mandatory_member}` should be assigned to {self.__class__.__name__} objects.")

     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.forward(*args, **kwds)

     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
-        # filtering uselss args or kwarg such as None.
+        # filtering useless args or kwarg such as None.
         args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
         kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
         output = self.runtime(*args, **kwargs)
@@ -76,17 +76,21 @@ class UnavailableRuntime:
 class ContextRblnConfig:
     _local = threading.local()

-    def __init__(self, device=None, device_map=None, create_runtimes=None, optimize_host_mem=None):
+    def __init__(
+        self, device=None, device_map=None, create_runtimes=None, optimize_host_mem=None, activate_profiler=None
+    ):
         self.device = device
         self.device_map = device_map
         self.create_runtimes = create_runtimes
         self.optimize_host_mem = optimize_host_mem
+        self.activate_profiler = activate_profiler

     def __enter__(self):
         self._local.device = self.device
         self._local.device_map = self.device_map
         self._local.create_runtimes = self.create_runtimes
         self._local.optimize_host_memory = self.optimize_host_mem
+        self._local.activate_profiler = self.activate_profiler
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -94,6 +98,7 @@ class ContextRblnConfig:
         self._local.device_map = None
         self._local.create_runtimes = None
         self._local.optimize_host_memory = None
+        self._local.activate_profiler = None

     @classmethod
     def get_current_context(cls):
@@ -102,4 +107,5 @@ class ContextRblnConfig:
             "device_map": getattr(cls._local, "device_map", None),
             "create_runtimes": getattr(cls._local, "create_runtimes", None),
             "optimize_host_memory": getattr(cls._local, "optimize_host_memory", None),
+            "activate_profiler": getattr(cls._local, "activate_profiler", None),
         }
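A minimal sketch (not part of the diff) of the context manager with the new flag; the pipeline class and model id are placeholders:

from optimum.rbln import RBLNStableDiffusionPipeline  # placeholder model class
from optimum.rbln.utils.runtime_utils import ContextRblnConfig

# Options set here are read back through ContextRblnConfig.get_current_context()
# when runtimes are created inside the block, including the new profiler switch.
with ContextRblnConfig(device=0, create_runtimes=True, activate_profiler=True):
    pipe = RBLNStableDiffusionPipeline.from_pretrained("compiled-sd-rbln")  # placeholder id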
optimum/rbln/utils/save_utils.py
@@ -1,3 +1,16 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 # Copyright 2024 Rebellions Inc.

 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +34,10 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+"""
+Refer to huggingface/optimum/blob/4fdeea77d71e79451ba53e0c1f9d8f37e9704268/optimum/utils/save_utils.py
+"""
+
 import logging
 from pathlib import Path
 from typing import List, Union