optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
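
The headline change in this release is a restructuring of the base-modeling layer (items 31, 33, and 34 above): optimum/rbln/modeling_base.py shrinks by roughly 760 lines while the new optimum/rbln/modeling.py and optimum/rbln/modeling_diffusers.py modules take over, and every model file switches its RBLNModel import accordingly. A minimal compatibility sketch, assuming only the internal module path moved (the top-level optimum.rbln re-export appears unchanged):

# Hedged sketch: import RBLNModel across both layouts. The hunks below show
# the per-file change from "from ....modeling_base import RBLNModel" to
# "from ....modeling import RBLNModel".
try:
    from optimum.rbln.modeling import RBLNModel  # 0.1.15 layout
except ImportError:
    from optimum.rbln.modeling_base import RBLNModel  # 0.1.12 layout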
--- a/optimum/rbln/transformers/models/t5/modeling_t5.py
+++ b/optimum/rbln/transformers/models/t5/modeling_t5.py
@@ -22,12 +22,23 @@
 # from Rebellions Inc.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable
-
-from transformers import T5ForConditionalGeneration
-
-from ....modeling_config import RBLNConfig
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+import transformers
+from transformers import (
+    AutoModelForTextEncoding,
+    PretrainedConfig,
+    T5EncoderModel,
+    T5ForConditionalGeneration,
+)
+from transformers.modeling_outputs import BaseModelOutput
+
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 from .t5_architecture import T5Wrapper
 
@@ -35,7 +46,147 @@ from .t5_architecture import T5Wrapper
 logger = get_logger()
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+class RBLNRuntimeModel(RBLNPytorchRuntime):
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.FloatTensor,
+        head_mask: torch.FloatTensor,
+        inputs_embeds: torch.FloatTensor,
+        **kwargs,
+    ):
+        return super().forward(
+            input_ids,
+            attention_mask,
+            head_mask,
+            inputs_embeds,
+            **kwargs,
+        )
+
+
+class T5EncoderWrapper(torch.nn.Module):
+    def __init__(self, model: "T5EncoderModel") -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(self, *args, **kwargs):
+        kwargs.pop("return_dict", None)
+        return self.model(*args, **kwargs, return_dict=False)
+
+
+class RBLNT5EncoderModel(RBLNModel):
+    auto_model_class = AutoModelForTextEncoding
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def __post_init__(self, **kwargs):
+        self.model = RBLNRuntimeModel(runtime=self.model[0])
+
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5EncoderWrapper(model)
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size", 1)
+        max_sequence_length = rbln_config.get("max_sequence_length", 256)
+        model_input_names = ["input_ids"]
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "max_seq_len": max_sequence_length,
+                "model_input_names": model_input_names,
+            }
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_kwargs: Dict[str, Any] = {},
+    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
+        max_position_embeddings = getattr(model_config, "n_positions", None)
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(input_names_order)})"
+                )
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        encoder_outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if not return_dict:
+            return (encoder_outputs,)
+        else:
+            return BaseModelOutput(last_hidden_state=encoder_outputs)
 
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
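
The hunk above introduces RBLNT5EncoderModel, a standalone T5 encoder runtime; its update_rbln_config_using_pipe hook suggests it also serves as a text encoder for the new Stable Diffusion 3 pipelines added in this release. A minimal usage sketch, assuming the class is re-exported from optimum.rbln and follows the library's export=True / rbln_* keyword convention; the checkpoint id is illustrative:

from optimum.rbln import RBLNT5EncoderModel

# Per _get_rbln_config above, rbln_max_seq_len falls back to
# config.n_positions, then to the tokenizer's model_max_length,
# when not given explicitly.
model = RBLNT5EncoderModel.from_pretrained(
    "google-t5/t5-base",   # illustrative checkpoint
    export=True,           # convert and compile from the PyTorch weights
    rbln_max_seq_len=512,
    rbln_batch_size=1,
)
model.save_pretrained("t5-base-rbln")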
--- a/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -28,7 +28,7 @@ import torch
 from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
--- a/optimum/rbln/transformers/models/whisper/modeling_whisper.py
+++ b/optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -36,7 +36,7 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from .generation_whisper import RBLNWhisperGenerationMixin
@@ -102,7 +102,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
 class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
-    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
--- a/optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -22,12 +22,12 @@
 # from Rebellions Inc.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
-from transformers import PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers import PretrainedConfig
 
-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
@@ -38,38 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNXLMRobertaModel(RBLNModel):
-    original_model_class = XLMRobertaModel
-    original_config_class = XLMRobertaConfig
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        model: "PreTrainedModel" = super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_kwargs=rbln_kwargs,
-            library_name="transformers",
-        )
-
-        return model
-
     @classmethod
     def _get_rbln_config(
         cls,
--- a/optimum/rbln/transformers/utils/rbln_quantization.py
+++ b/optimum/rbln/transformers/utils/rbln_quantization.py
@@ -22,21 +22,117 @@
 # from Rebellions Inc.
 
 
-from typing import Any
+import functools
+import glob
+import os
+from typing import Any, Callable, Dict, Optional
 
 import torch
+from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F
 
+from ...utils.logging import get_logger
+
+
+logger = get_logger()
+
+SUPPORTED_QUANTIZATIONS: Dict[str, list[str]] = {
+    "rbln": ["w4a16"],
+}
+
+
+class QuantizationManager:
+    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
+    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+
+    @staticmethod
+    def _raise_invalid_config_error(
+        key: str, value: str, valid_values: list[str], context: Optional[str] = None
+    ) -> None:
+        context_info = f" for {context}" if context else ""
+        valid_values_str = ", ".join(valid_values)
+        raise ValueError(f"Invalid {key}: {value}{context_info}. " f"Supported values are: {valid_values_str}")
+
+    @staticmethod
+    def validate_quantization_config(quantize_config: Optional[dict]) -> Optional[dict]:
+        if not quantize_config:
+            return None
+
+        q_format = quantize_config.get("format")
+        q_precision = quantize_config.get("precision")
+
+        if q_format not in SUPPORTED_QUANTIZATIONS:
+            QuantizationManager._raise_invalid_config_error(
+                "quantization format", q_format, list(SUPPORTED_QUANTIZATIONS.keys())
+            )
+
+        if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+            QuantizationManager._raise_invalid_config_error(
+                "precision", q_precision, SUPPORTED_QUANTIZATIONS[q_format], q_format
+            )
+
+        return quantize_config
+
+    @classmethod
+    def _set_env_var(cls, name: str, value: str) -> None:
+        os.environ[name] = value
+
+    @classmethod
+    def _unset_env_var(cls, name: str) -> None:
+        os.environ.pop(name, None)
+
+    @classmethod
+    def set_quantization_env(cls, quantize_config: Optional[dict]) -> Optional[str]:
+        quantize_config = cls.validate_quantization_config(quantize_config)
+        if quantize_config:
+            q_precision: str = quantize_config["precision"]
+            quant_bits = q_precision.split("w")[1].split("a")[0]
+            cls._set_env_var(cls.RBLN_QUANT_BITS_ENV, quant_bits)
+            return cls.RBLN_QUANT_BITS_ENV
+        return None
+
+    @classmethod
+    def reset_quantization_env(cls, env_var_name: Optional[str]) -> None:
+        if env_var_name:
+            cls._unset_env_var(env_var_name)
+
+    @classmethod
+    def with_quantization_env(cls, func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            quantize_config = kwargs.get("quantize_config")
+            quantize_env_var = cls.set_quantization_env(quantize_config)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                cls.reset_quantization_env(quantize_env_var)
+
+        return wrapper
+
 
 # Constants
 QUANTIZED_WEIGHTS = {
-    "q_proj", "k_proj", "v_proj", "o_proj",
-    "gate_proj", "up_proj", "down_proj",
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
 }
 
 
-def update_layers_to_quantized(module: torch.nn.Module) -> None:
+def prepare_model_for_quantization(model: torch.nn.Module, model_id: str, n_layer: Optional[int] = None) -> None:
+    """
+    Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
+    """
+    update_layers_to_quantize(model)
+    load_weights(model, model_id, n_layer)
+
+
+def update_layers_to_quantize(module: torch.nn.Module) -> None:
     """
     Updates specified linear layers to quantized (qlinear) layers in the given module.
     """
@@ -49,7 +145,33 @@ def update_layers_to_quantized(module: torch.nn.Module) -> None:
         processed_layers.append(name)
 
     if processed_layers:
-        print(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def load_weights(model, model_id, n_layer=None):
+    """
+    Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))
+    safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+    target_layers = list(range(n_layer)) if n_layer is not None else None
+
+    for safetensor_file in safetensor_files:
+        file_data = load_file(safetensor_file)
+        for key, value in file_data.items():
+            if target_layers is not None:
+                parts = key.split(".")
+
+                if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                    continue
+
+            if key in model_params:
+                model_params[key].data.copy_(value)
+            elif key in model_buffers:
+                model_buffers[key].data.copy_(value)
 
 
 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
@@ -81,6 +203,7 @@ def create_qlinear(layer: Linear) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
+
    def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if inputs.dtype != self.scales.dtype:
            raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
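
The new QuantizationManager scopes the RBLN_QUANT_BITS environment variable around a compilation call: the bit width is parsed out of precision strings such as w4a16 (4-bit weights, 16-bit activations), exported for the duration of the wrapped call, and unset afterwards. A small sketch of that behavior, assuming RBLN_QUANT_BITS is not already set in the environment:

import os

from optimum.rbln.transformers.utils.rbln_quantization import QuantizationManager


@QuantizationManager.with_quantization_env
def compile_step(quantize_config=None):
    # While the wrapped call runs, RBLN_QUANT_BITS holds the parsed bit width.
    return os.environ.get(QuantizationManager.RBLN_QUANT_BITS_ENV)


# "w4a16".split("w")[1].split("a")[0] -> "4"
assert compile_step(quantize_config={"format": "rbln", "precision": "w4a16"}) == "4"
# The variable is unset again once the call returns.
assert QuantizationManager.RBLN_QUANT_BITS_ENV not in os.environ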
--- /dev/null
+++ b/optimum/rbln/utils/decorator_utils.py
@@ -0,0 +1,59 @@
+from functools import wraps
+
+from .logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def remove_compile_time_kwargs(func):
+    """
+    Decorator to handle compile-time parameters during inference.
+
+    For RBLN-optimized pipelines, several parameters must be determined during compilation
+    and cannot be modified during inference. This decorator:
+    1. Removes and warns about LoRA scale in cross_attention_kwargs
+    2. Removes and warns about image dimension parameters (height, width)
+
+    Args:
+        func: The pipeline's __call__ method to be wrapped
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        height_exists = "height" in kwargs and kwargs["height"] is not None
+        width_exists = "width" in kwargs and kwargs["width"] is not None
+        compiled_image_size = self.vae.image_size
+        if height_exists or width_exists:
+            if kwargs["height"] == compiled_image_size[0] and kwargs["width"] == compiled_image_size[1]:
+                pass
+            else:
+                logger.warning(
+                    "Image dimension parameters (`height`, `width`) will be ignored during inference. "
+                    "Image dimensions must be specified during model compilation using from_pretrained()."
+                )
+                kwargs.pop("width", None)
+                kwargs.pop("height", None)
+
+        if "cross_attention_kwargs" in kwargs:
+            cross_attention_kwargs = kwargs.get("cross_attention_kwargs")
+            if not cross_attention_kwargs:
+                return func(self, *args, **kwargs)
+
+            has_scale = "scale" in cross_attention_kwargs
+            if has_scale:
+                logger.warning(
+                    "LoRA scale in cross_attention_kwargs will be ignored during inference. "
+                    "To adjust LoRA scale, specify it during model compilation using from_pretrained()."
+                )
+
+                # If scale is the only key, set to None
+                # Otherwise, remove scale and preserve other settings
+                if len(cross_attention_kwargs) == 1:
+                    kwargs["cross_attention_kwargs"] = None
+                else:
+                    kwargs["cross_attention_kwargs"].pop("scale")
+
+        return func(self, *args, **kwargs)
+
+    return wrapper
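
A hypothetical sketch of the decorator in action; the pipeline class and its vae attribute are stand-ins for an RBLN diffusion pipeline, which exposes the compiled image size as self.vae.image_size. Keyword arguments that were fixed at compile time are stripped with a warning instead of reaching the compiled graph:

from types import SimpleNamespace

from optimum.rbln.utils.decorator_utils import remove_compile_time_kwargs


class FakeRBLNPipeline:  # stand-in, not a real optimum-rbln class
    def __init__(self):
        # The decorator compares requested dims against self.vae.image_size.
        self.vae = SimpleNamespace(image_size=(512, 512))

    @remove_compile_time_kwargs
    def __call__(self, **kwargs):
        return kwargs


pipe = FakeRBLNPipeline()
out = pipe(height=768, width=768, cross_attention_kwargs={"scale": 0.7})
assert "height" not in out and "width" not in out  # mismatched dims dropped
assert out["cross_attention_kwargs"] is None       # lone LoRA scale dropped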
--- /dev/null
+++ b/optimum/rbln/utils/hub.py
@@ -0,0 +1,131 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import os
+from pathlib import Path
+from typing import List, Optional, Union
+
+from huggingface_hub import HfApi, HfFolder, hf_hub_download
+
+
+class PushToHubMixin:
+    def push_to_hub(
+        self,
+        save_directory: str,
+        repository_id: str,
+        private: Optional[bool] = None,
+        use_auth_token: Union[bool, str] = True,
+    ) -> str:
+        huggingface_token = _get_huggingface_token(use_auth_token)
+        api = HfApi()
+
+        api.create_repo(
+            token=huggingface_token,
+            repo_id=repository_id,
+            exist_ok=True,
+            private=private,
+        )
+        for path, subdirs, files in os.walk(save_directory):
+            for name in files:
+                local_file_path = os.path.join(path, name)
+                _, hub_file_path = os.path.split(local_file_path)
+                # FIXME: when huggingface_hub fixes the return of upload_file
+                try:
+                    api.upload_file(
+                        token=huggingface_token,
+                        repo_id=f"{repository_id}",
+                        path_or_fileobj=os.path.join(os.getcwd(), local_file_path),
+                        path_in_repo=hub_file_path,
+                    )
+                except KeyError:
+                    pass
+                except NameError:
+                    pass
+
+
+def pull_compiled_model_from_hub(
+    model_id: Union[str, Path],
+    subfolder: str,
+    use_auth_token: Optional[Union[bool, str]],
+    revision: Optional[str],
+    cache_dir: Optional[str],
+    force_download: bool,
+    local_files_only: bool,
+) -> Path:
+    """Pull model files from the Hugging Face Hub."""
+    huggingface_token = _get_huggingface_token(use_auth_token)
+    repo_files = list(
+        map(
+            Path,
+            HfApi().list_repo_files(model_id, revision=revision, token=huggingface_token),
+        )
+    )
+
+    pattern_rbln = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
+    rbln_files = [p for p in repo_files if p.match(pattern_rbln)]
+
+    pattern_config = "rbln_config.json" if subfolder == "" else f"{subfolder}/rbln_config.json"
+    rbln_config_filenames = [p for p in repo_files if p.match(pattern_config)]
+
+    validate_files(rbln_files, rbln_config_filenames, f"repository {model_id}")
+
+    filenames = [str(path) for path in repo_files]
+
+    for filename in filenames:
+        rbln_config_cache_path = hf_hub_download(
+            repo_id=model_id,
+            filename=filename,
+            subfolder=subfolder,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
+
+    return Path(rbln_config_cache_path).parent
+
+
+def validate_files(
+    files: List[Path],
+    config_files: List[Path],
+    location: str,
+):
+    """Validate the presence and count of required files."""
+    if len(files) == 0:
+        raise FileNotFoundError(f"Could not find any rbln model file in {location}")
+
+    if len(config_files) == 0:
+        raise FileNotFoundError(f"Could not find `rbln_config.json` file in {location}")
+
+    if len(config_files) > 1:
+        raise FileExistsError(f"Multiple rbln_config.json files found in {location}. This is not expected.")
+
+
+def _get_huggingface_token(use_auth_token: Union[bool, str]) -> str:
+    if isinstance(use_auth_token, str):
+        return use_auth_token
+    elif use_auth_token:
+        return HfFolder.get_token()
+    else:
+        raise ValueError("`use_auth_token` must be provided to interact with the Hugging Face Hub.")
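
validate_files encodes the invariant behind the loader: a compiled-model location must contain at least one *.rbln artifact and exactly one rbln_config.json. A minimal sketch (file names and the location string are illustrative):

from pathlib import Path

from optimum.rbln.utils.hub import validate_files

# Passes silently: at least one compiled artifact, exactly one config.
validate_files(
    files=[Path("prefill.rbln"), Path("decoder.rbln")],
    config_files=[Path("rbln_config.json")],
    location="repository my-org/my-model",
)

# With files=[] it raises FileNotFoundError; with two config entries it
# raises FileExistsError, matching the checks above.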
--- a/optimum/rbln/utils/import_utils.py
+++ b/optimum/rbln/utils/import_utils.py
@@ -37,6 +37,27 @@ class VersionCompat:
 
 
 RBLN_VERSION_COMPATS = {
+    "0.1.15": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.6.2",
+            max_version="0.6.3",
+        ),
+    ],
+    "0.1.14": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.6.2",
+            max_version="0.6.3",
+        ),
+    ],
+    "0.1.13": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.6.0",
+            max_version="0.6.2",
+        ),
+    ],
     "0.1.12": [
         VersionCompat(
             package_name="rebel-compiler",
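
These entries pin each optimum-rbln release to a rebel-compiler range: 0.1.14 and 0.1.15 both pair with rebel-compiler 0.6.2, while 0.1.13 pairs with the 0.6.0/0.6.1 series. A sketch of the check such entries presumably drive, assuming min_version is inclusive and max_version exclusive (the non-overlapping 0.1.13/0.1.14 boundaries at 0.6.2 suggest exactly that):

from packaging.version import Version


def is_compatible(installed: str, min_version: str, max_version: str) -> bool:
    # Assumed semantics: min inclusive, max exclusive.
    v = Version(installed)
    return Version(min_version) <= v < Version(max_version)


# 0.1.15 entry: rebel-compiler >= 0.6.2, < 0.6.3
assert is_compatible("0.6.2", "0.6.2", "0.6.3")
assert not is_compatible("0.6.3", "0.6.2", "0.6.3")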