optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/utils/rbln_quantization.py
@@ -14,37 +14,67 @@
 
 import glob
 import os
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import torch
+from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F
+from transformers import AutoConfig
+from transformers.modeling_utils import get_state_dict_dtype, no_init_weights
 
 from ...configuration_utils import RBLNSerializableConfigProtocol
 from ...utils.logging import get_logger
 
 
+if TYPE_CHECKING:
+    from transformers.models.auto.modeling_auto import _BaseAutoModelClass
+
 logger = get_logger()
 
 
+# Constants
+QUANTIZED_WEIGHTS = {
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+}
+
+# Common alias sets seen in community checkpoints
+VARIANT_ALIASES: Dict[str, List[str]] = {
+    "weight_scale": ["weight_scale", "scales", "w_scale", "scale"],
+    "input_scale": ["input_scale", "act_scale", "activation_scale", "a_scale"],
+    "kv_scale": ["kv_scale", "kv_scales"],
+    "k_scale": ["k_scale", "k_scales"],
+    "v_scale": ["v_scale", "v_scales"],
+}
+
+
 class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
     SUPPORTED_FORMATS = ["rbln"]
-    SUPPORTED_WEIGHTS = ["int4", "fp16"]
-    SUPPORTED_ACTIVATIONS = ["fp16"]
-
-    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
-    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    SUPPORTED_WEIGHTS = ["int4", "int8", "fp8", "fp16"]
+    SUPPORTED_ACTIVATIONS = ["int8", "fp8", "fp16"]
+    SUPPORTED_KVCACHES = ["fp8", "fp16"]
     RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
 
     def __init__(
         self,
         format: Optional[str] = None,
-        precision: Optional[str] = None,
         weights: Optional[str] = None,
         activations: Optional[str] = None,
+        kv_caches: Optional[str] = None,
+        *,
+        precision: Optional[str] = None,
     ):
-        self.format = format
+        self.format = format or "rbln"
+        if self.format not in self.SUPPORTED_FORMATS:
+            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
+
         if precision is not None:
             logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
             if any(precision_arg is not None for precision_arg in (weights, activations)):
@@ -58,6 +88,7 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
 
         self.weights = weights or "fp16"
         self.activations = activations or "fp16"
+        self.kv_caches = kv_caches or "fp16"
         self._validate()
 
     def _validate(self):
@@ -69,106 +100,126 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
             raise ValueError(
                 f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
             )
+        if self.kv_caches not in self.SUPPORTED_KVCACHES:
+            raise ValueError(
+                f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
+            )
         if self.weights == "fp16" and self.activations == "fp16":
-            raise ValueError("weights and activations cannot be both fp16. It is meaningless.")
+            raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")
 
     def _prepare_for_serialization(self) -> Dict[str, Any]:
         return {
             "format": self.format,
             "weights": self.weights,
             "activations": self.activations,
+            "kv_caches": self.kv_caches,
         }
 
     def maybe_set_quantization_env(self):
-        quant_bits = None
         if self.weights == "int4":
-            quant_bits = "4"
-            os.environ[self.RBLN_QUANT_BITS_ENV] = quant_bits
+            os.environ[self.RBLN_QUANT_BITS_ENV] = "4"
 
     def maybe_reset_quantization_env(self):
         if self.RBLN_QUANT_BITS_ENV in os.environ:
             os.environ.pop(self.RBLN_QUANT_BITS_ENV)
 
 
-# Constants
-QUANTIZED_WEIGHTS = {
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-}
+class QuantizedLayerFactory:
+    def __init__(self, quantization_config: RBLNQuantizationConfig):
+        self.quantization_config = quantization_config
 
+    def create_linear(self, layer: Linear) -> Linear:
+        if self.quantization_config.weights in ["int4", "int8"]:
+            return self.create_qlinear(layer)
+        elif self.quantization_config.weights == "fp8":
+            return self.create_fp8linear(layer)
+        else:
+            raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")
 
-def prepare_model_for_quantization(
-    model: torch.nn.Module,
+    def create_qlinear(self, layer: Linear) -> Linear:
+        return create_qlinear(layer, self.quantization_config)
+
+    def create_fp8linear(self, layer: Linear) -> Linear:
+        return create_fp8linear(layer, self.quantization_config)
+
+
+def get_quantized_model(
+    hf_auto_model_class: Type["_BaseAutoModelClass"],
     model_id: str,
-    n_layer: Optional[int] = None,
     use_auth_token: Optional[Union[bool, str]] = None,
     revision: Optional[str] = None,
     cache_dir: Optional[str] = None,
     force_download: bool = False,
     local_files_only: bool = False,
-) -> torch.nn.Module:
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+    **kwargs,
+):
     """
-    Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
+    Get a quantized model from a model class and model id.
     """
-    update_layers_to_quantize(model)
-    load_weights(
-        model,
+    # torch_dtype should not be passed to AutoConfig.from_pretrained
+    # since it doesn't support 'auto'
+    torch_dtype = kwargs.pop("torch_dtype", None)
+    if torch_dtype is not None:
+        logger.warning(
+            "torch_dtype is not supported for quantized models. "
+            "It will be ignored and the dtype of the model will be determined by the weights."
+        )
+        torch_dtype = None
+
+    # get paths of safetensors files in the model repo
+    safetensor_files = load_weight_files(
         model_id,
-        n_layer,
         use_auth_token=use_auth_token,
         revision=revision,
         cache_dir=cache_dir,
         force_download=force_download,
         local_files_only=local_files_only,
     )
-    return model
 
+    # load safetensors files into memory
+    safetensors = [load_file(safetensor_file) for safetensor_file in safetensor_files]
 
-def update_layers_to_quantize(module: torch.nn.Module) -> None:
-    """
-    Updates specified linear layers to quantized (qlinear) layers in the given module.
-    """
+    # get the dtype of the model from the first safetensor file
+    torch_dtype = get_state_dict_dtype(safetensors[0])
 
-    logger.debug("Updating layers to be quantized") # TODO(jongho): remove.
-    processed_layers = []
+    config = AutoConfig.from_pretrained(
+        model_id,
+        use_auth_token=use_auth_token,
+        revision=revision,
+        cache_dir=cache_dir,
+        force_download=force_download,
+        local_files_only=local_files_only,
+        **kwargs,
+    )
 
-    for name, layer in module.named_modules():
-        if is_target_for_qlinear_replacement(name, layer):
-            parent_module, layer_name = get_parent_and_child(module, name)
-            setattr(parent_module, layer_name, create_qlinear(layer))
-            processed_layers.append(name)
+    with no_init_weights():
+        model = hf_auto_model_class.from_config(config, torch_dtype=torch_dtype)
 
-    if processed_layers:
-        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+    # Quantize the model
+    update_layers_to_quantize(model, rbln_quantization)
 
+    # Load weights into the model
+    load_weights_from_files(model, safetensors, rbln_quantization)
 
-def load_weights(
-    model,
-    model_id,
-    n_layer=None,
-    use_auth_token=None,
-    revision=None,
-    cache_dir=None,
-    force_download=False,
-    local_files_only=False,
-):
+    return model
+
+
+def load_weight_files(
+    model_id: str,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+) -> list[str]:
     """
-    Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+    Discover and download safetensors files for the given model id.
 
     """
-    model_params = dict(model.named_parameters(recurse=True))
-    model_buffers = dict(model.named_buffers(recurse=True))
-
     if os.path.isdir(model_id):
         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
     else:
-        from huggingface_hub import hf_hub_download, list_repo_files
-
         try:
             # List all files in the repository
             repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
@@ -195,27 +246,226 @@ def load_weights(
     if not safetensor_files:
         raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")
 
-    target_layers = list(range(n_layer)) if n_layer is not None else None
+    return safetensor_files
 
-    unloaded_keys = []
-    for safetensor_file in safetensor_files:
-        file_data = load_file(safetensor_file)
-        for key, value in file_data.items():
-            if target_layers is not None:
-                parts = key.split(".")
 
-                if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                    continue
+def update_layers_to_quantize(
+    module: torch.nn.Module,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+) -> None:
+    """
+    Updates specified linear layers to quantized (qlinear) layers in the given module.
+    """
+
+    processed_layers = []
+    quantized_layer_factory = QuantizedLayerFactory(rbln_quantization)
+
+    for name, layer in module.named_modules():
+        if is_target_for_qlinear_replacement(name, layer):
+            parent_module, layer_name = get_parent_and_child(module, name)
+            setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
+            processed_layers.append(name)
 
+    if processed_layers:
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def _last_segment(key: str) -> str:
+    parts = key.split(".")
+    return parts[-1]
+
+
+def _replace_last_with(key: str, new_tail: str) -> str:
+    parts = key.split(".")
+    return ".".join(parts[:-1] + new_tail.split("."))
+
+
+def _matches_any_alias(key: str, kind: str) -> bool:
+    tail = _last_segment(key)
+    return tail in VARIANT_ALIASES.get(kind, [])
+
+
+def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
+    if t.ndim == 0:
+        return t
+    return t.reshape(-1).amax()
+
+
+def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
+    s = scale
+    if s.ndim == 0:
+        # scalar -> expand to [out_features, 1]
+        return s.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 1:
+        if s.numel() == 1:
+            return s.reshape(1, 1).expand(out_features, 1).contiguous()
+        if s.numel() == out_features:
+            return s.reshape(out_features, 1).contiguous()
+        # fallback: reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 2:
+        if s.shape == (out_features, 1):
+            return s.contiguous()
+        if s.shape == (1, out_features):
+            return s.transpose(0, 1).contiguous()
+        # fallback: reduce to [out_features] on non-out dims if possible
+        if s.shape[0] == out_features:
+            v = s
+            while v.ndim > 2:
+                v = v.amax(dim=-1)
+            if v.shape[-1] != 1:
+                v = v.amax(dim=-1, keepdim=True)
+            return v.contiguous()
+        # otherwise reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    # high-rank: reduce to scalar then expand
+    v = _reduce_to_scalar(s)
+    return v.reshape(1, 1).expand(out_features, 1).contiguous()
+
+
+def _kv_split_items(base_key: str, tensor: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
+    # base_key is the original key whose last token was 'kv_scale'
+    # We produce keys with 'k_proj.k_scale' and 'v_proj.v_scale'
+    if tensor.ndim == 1 and tensor.numel() >= 2:
+        tk, tv = tensor[0], tensor[1]
+    elif tensor.ndim == 2 and tensor.shape[0] >= 2 and tensor.shape[1] == 1:
+        tk, tv = tensor[0, 0], tensor[1, 0]
+    else:
+        tk = tv = tensor
+    k_key = _replace_last_with(base_key, "k_proj.k_scale")
+    v_key = _replace_last_with(base_key, "v_proj.v_scale")
+    return [(k_key, tk), (v_key, tv)]
+
+
+def canonicalize_checkpoint_items(
+    model: torch.nn.Module,
+    items: Iterable[Tuple[str, torch.Tensor]],
+    rbln_quantization: Optional[RBLNQuantizationConfig],
+) -> List[Tuple[str, torch.Tensor]]:
+    params = dict(model.named_parameters(recurse=True))
+    results: List[Tuple[str, torch.Tensor]] = []
+
+    for key, value in items:
+        t = value
+        # Normalize weight scale variants
+        if _matches_any_alias(key, "weight_scale"):
+            # rename last token to the canonical weight scale key
+            target_key = _replace_last_with(key, "weight_scale")
+
+            # Determine associated weight param to infer shape
+            weight_key = _replace_last_with(target_key, "weight")
+            out_features = None
+            if weight_key in params:
+                wshape = params[weight_key].shape
+                if len(wshape) == 2:
+                    out_features = int(wshape[0])
+
+            if rbln_quantization.weights in ["int4", "int8"] and out_features is not None:
+                t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
+            elif rbln_quantization.weights == "fp8":
+                # Use a conservative scalar scale to ensure broadcastability
+                t = _reduce_to_scalar(t.to(torch.float32))
+            else:
+                t = t.to(torch.float32)
+
+            results.append((target_key, t))
+            continue
+
+        # Normalize input/activation scale variants
+        if _matches_any_alias(key, "input_scale"):
+            target_key = _replace_last_with(key, "input_scale")
+            t = _reduce_to_scalar(t.to(torch.float32))
+            results.append((target_key, t))
+            continue
+
+        # KV scale handling
+        if _matches_any_alias(key, "kv_scale"):
+            # For quark-like formats, expand to k/v
+            kv_items = _kv_split_items(key, t.to(torch.float32))
+            for k2, v2 in kv_items:
+                results.append((k2, v2))
+            continue
+
+        if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
+            results.append((key, t.to(torch.float32)))
+            continue
+
+        # Default: passthrough
+        results.append((key, t))
+
+    return results
+
+
+def load_weights_from_files(
+    model: torch.nn.Module,
+    safetensors: List[Dict[str, torch.Tensor]],
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+):
+    """
+    Load safetensor file data directly into the model from provided safetensor files.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))
+
+    unloaded_keys = []
+    loaded_input_scale = False
+    loaded_kv_scale = False
+    loaded_weight_scale = False
+
+    for safetensor in safetensors:
+        # Normalize all (key, tensor) pairs to the internal schema
+        normalized_items = canonicalize_checkpoint_items(
+            model=model,
+            items=safetensor.items(),
+            rbln_quantization=rbln_quantization,
+        )
+
+        for key, value in normalized_items:
+            # Track which types of scales were observed (post-normalization)
+            if key.endswith("input_scale"):
+                loaded_input_scale = True
+            if key.endswith("weight_scale"):
+                loaded_weight_scale = True
+            if key.endswith("k_scale") or key.endswith("v_scale"):
+                loaded_kv_scale = True
+
+            # Copy into parameters or buffers
             if key in model_params:
+                # Ensure dtype compatibility
+                if model_params[key].dtype != value.dtype:
+                    value = value.to(model_params[key].dtype)
                 model_params[key].data.copy_(value)
             elif key in model_buffers:
+                if model_buffers[key].dtype != value.dtype:
+                    value = value.to(model_buffers[key].dtype)
                 model_buffers[key].data.copy_(value)
             else:
                 unloaded_keys.append(key)
 
     if len(unloaded_keys) > 0:
         logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
+    if not loaded_input_scale and rbln_quantization.activations == "fp8":
+        raise ValueError(
+            "No input_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_weight_scale and rbln_quantization.weights == "fp8":
+        raise ValueError(
+            "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
+        raise ValueError(
+            "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
+        logger.warning(
+            "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
+        )
 
 
 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
@@ -225,6 +475,10 @@ def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
     return layer_name.split(".")[-1] in QUANTIZED_WEIGHTS and isinstance(layer, torch.nn.Linear)
 
 
+def is_target_for_adding_kv_scales(layer_name: str) -> bool:
+    return layer_name.split(".")[-1] in ["self_attn"]
+
+
 def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
     """
     Splits the full layer name to retrieve the parent module and the child layer.
@@ -243,22 +497,84 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
     return obj
 
 
-def create_qlinear(layer: Linear) -> Linear:
+def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
 
     def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        if inputs.dtype != self.scales.dtype:
-            raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
+        weight_scale = self.weight_scale
+        if inputs.dtype != weight_scale.dtype:
+            raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}")
 
         w_fp = self.weight.type(inputs.dtype)
-        w_fp *= self.scales.view(-1, 1)
+        w_fp *= weight_scale.view(-1, 1)
         return F.linear(inputs, w_fp, self.bias)
 
     # Convert weight to int8 and add scale parameter
     layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
-    layer.scales = Parameter(torch.ones(layer.out_features, dtype=torch.float32), requires_grad=False)
+    layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
    layer.forward = lambda inputs: qlinear_forward(layer, inputs)
 
     return layer
+
+
+def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
+    """
+    Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
+    """
+
+    def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+        return qweight
+
+    def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
+        A = A.type(out_dtype)
+        B = B.type(out_dtype)
+
+        if A_scale is not None:
+            A *= A_scale
+        if B_scale is not None:
+            B *= B_scale.to(out_dtype)
+
+        output = torch.nn.functional.linear(A, B, bias=bias)
+        return output
+
+    def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.input_scale:
+            input = static_per_tensor_quantize(x, self.input_scale)
+        else:
+            input = x
+
+        if self.weight_scale:
+            # broadcast weight_scale to vector
+            weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
+        else:
+            weight_scale = None
+        output = fp8_gemm(
+            A=input,
+            A_scale=self.input_scale,
+            B=self.weight,
+            B_scale=weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        return output
+
+    layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
+    layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    if rbln_quantization.activations == "fp8":
+        layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+    else:
+        layer.input_scale = None
+
+    if rbln_quantization.kv_caches == "fp8":
+        layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+        layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
+
+    return layer
optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (new file)
@@ -0,0 +1,79 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
+
+from torch.nn import Module
+
+from ...modeling import RBLNModel
+
+
+if TYPE_CHECKING:
+    import rebel
+
+
+class LoopProcessor(Module, ABC):
+    def __init__(self, model: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__()
+        self.model = model
+
+    def __repr__(self) -> str:
+        return repr(self.model)
+
+    def _is_batch_implemented(self) -> bool:
+        return self._forward_batch.__func__ is not LoopProcessor._forward_batch
+
+    def forward(self, *args, force_loop: bool = False, **kwargs) -> Any:
+        if not force_loop and self._is_batch_implemented():
+            return self._forward_batch(*args, **kwargs)
+        else:
+            return self._forward_loop(*args, **kwargs)
+
+    def _forward_loop(self, *args, **kwargs) -> Any:
+        batch_size = self._get_batch_size(*args, **kwargs)
+
+        if not isinstance(batch_size, int) or batch_size == 0:
+            return self._process_outputs([])
+
+        common_inputs = self._prepare_inputs_before_loop(*args, **kwargs)
+
+        outputs = []
+        for i in range(batch_size):
+            item_args, item_kwargs = self._prepare_inputs_for_iteration(i, common_inputs, *args, **kwargs)
+            item_output = self.model(*item_args, **item_kwargs)
+            outputs.append(item_output)
+
+        return self._process_outputs(outputs, **kwargs)
+
+    def _forward_batch(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("The batch processing logic (_forward_batch) is not implemented in this class.")
+
+    @abstractmethod
+    def _get_batch_size(self, *args, **kwargs) -> int:
+        pass
+
+    @abstractmethod
+    def _prepare_inputs_for_iteration(
+        self, index: int, common_inputs: Dict[str, Any], *args, **kwargs
+    ) -> Tuple[List[Any], Dict[str, Any]]:
+        pass
+
+    def _prepare_inputs_before_loop(self, *args, **kwargs) -> Dict[str, Any]:
+        pass
+
+    @abstractmethod
+    def _process_outputs(self, outputs: List[Any], **kwargs) -> Any:
+        pass
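
A minimal usage sketch (not part of the diff): the expanded RBLNQuantizationConfig shown above now accepts weights, activations, and kv_caches, so an fp8 configuration could presumably be built as follows. This assumes only what the constructor and _prepare_for_serialization in the diff show, and that the class is importable from the module path listed in the file table.

# Hypothetical sketch based solely on the constructor in the diff above.
from optimum.rbln.transformers.utils.rbln_quantization import RBLNQuantizationConfig

# format defaults to "rbln"; fp8 weights/activations with an fp8 KV cache.
quant_config = RBLNQuantizationConfig(weights="fp8", activations="fp8", kv_caches="fp8")
print(quant_config._prepare_for_serialization())
# expected: {'format': 'rbln', 'weights': 'fp8', 'activations': 'fp8', 'kv_caches': 'fp8'}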