optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (105)
  1. optimum/rbln/__init__.py +36 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +20 -4
  4. optimum/rbln/diffusers/__init__.py +7 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  22. optimum/rbln/diffusers/pipelines/auto_pipeline.py +237 -0
  23. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  24. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  27. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  28. optimum/rbln/modeling.py +3 -2
  29. optimum/rbln/modeling_base.py +29 -4
  30. optimum/rbln/ops/attn.py +158 -0
  31. optimum/rbln/ops/flash_attn.py +166 -0
  32. optimum/rbln/transformers/__init__.py +28 -0
  33. optimum/rbln/transformers/configuration_generic.py +6 -4
  34. optimum/rbln/transformers/modeling_generic.py +13 -8
  35. optimum/rbln/transformers/modeling_outputs.py +37 -0
  36. optimum/rbln/transformers/models/__init__.py +35 -16
  37. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +14 -0
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  40. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  41. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  43. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  44. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +7 -6
  45. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  46. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  47. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  48. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
  51. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  52. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  53. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
  54. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  55. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  56. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  57. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  58. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  59. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
  60. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  61. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +64 -258
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  63. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  64. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
  65. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
  66. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  67. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  68. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  69. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  75. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  76. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  77. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  78. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  79. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  80. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  81. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  82. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  83. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  84. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  85. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  86. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  87. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  88. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  89. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  90. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  91. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  92. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  93. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  94. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  95. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  96. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  97. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  99. optimum/rbln/transformers/utils/rbln_quantization.py +365 -65
  100. optimum/rbln/utils/runtime_utils.py +3 -3
  101. optimum/rbln/utils/submodule.py +10 -4
  102. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA +1 -1
  103. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/RECORD +105 -89
  104. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/WHEEL +0 -0
  105. {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/utils/rbln_quantization.py
@@ -14,9 +14,10 @@
 
 import glob
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
+from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors.torch import load_file
 from torch.nn import Linear, Parameter
 from torch.nn import functional as F
@@ -28,23 +29,47 @@ from ...utils.logging import get_logger
 logger = get_logger()
 
 
+# Constants
+QUANTIZED_WEIGHTS = {
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+}
+
+# Common alias sets seen in community checkpoints
+VARIANT_ALIASES: Dict[str, List[str]] = {
+    "weight_scale": ["weight_scale", "scales", "w_scale", "scale"],
+    "input_scale": ["input_scale", "act_scale", "activation_scale", "a_scale"],
+    "kv_scale": ["kv_scale", "kv_scales"],
+    "k_scale": ["k_scale", "k_scales"],
+    "v_scale": ["v_scale", "v_scales"],
+}
+
+
 class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
     SUPPORTED_FORMATS = ["rbln"]
-    SUPPORTED_WEIGHTS = ["int4", "fp16"]
-    SUPPORTED_ACTIVATIONS = ["fp16"]
-
-    # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
-    # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+    SUPPORTED_WEIGHTS = ["int4", "int8", "fp8", "fp16"]
+    SUPPORTED_ACTIVATIONS = ["int8", "fp8", "fp16"]
+    SUPPORTED_KVCACHES = ["fp8", "fp16"]
     RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
 
     def __init__(
         self,
         format: Optional[str] = None,
-        precision: Optional[str] = None,
         weights: Optional[str] = None,
         activations: Optional[str] = None,
+        kv_caches: Optional[str] = None,
+        *,
+        precision: Optional[str] = None,
     ):
-        self.format = format
+        self.format = format or "rbln"
+        if self.format not in self.SUPPORTED_FORMATS:
+            raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
+
         if precision is not None:
             logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
             if any(precision_arg is not None for precision_arg in (weights, activations)):
@@ -58,6 +83,7 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
 
         self.weights = weights or "fp16"
         self.activations = activations or "fp16"
+        self.kv_caches = kv_caches or "fp16"
         self._validate()
 
     def _validate(self):
@@ -69,37 +95,47 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
             raise ValueError(
                 f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
             )
+        if self.kv_caches not in self.SUPPORTED_KVCACHES:
+            raise ValueError(
+                f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
+            )
         if self.weights == "fp16" and self.activations == "fp16":
-            raise ValueError("weights and activations cannot be both fp16. It is meaningless.")
+            raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")
 
     def _prepare_for_serialization(self) -> Dict[str, Any]:
         return {
             "format": self.format,
             "weights": self.weights,
             "activations": self.activations,
+            "kv_caches": self.kv_caches,
         }
 
     def maybe_set_quantization_env(self):
-        quant_bits = None
         if self.weights == "int4":
-            quant_bits = "4"
-            os.environ[self.RBLN_QUANT_BITS_ENV] = quant_bits
+            os.environ[self.RBLN_QUANT_BITS_ENV] = "4"
 
     def maybe_reset_quantization_env(self):
         if self.RBLN_QUANT_BITS_ENV in os.environ:
             os.environ.pop(self.RBLN_QUANT_BITS_ENV)
 
 
-# Constants
-QUANTIZED_WEIGHTS = {
-    "q_proj",
-    "k_proj",
-    "v_proj",
-    "o_proj",
-    "gate_proj",
-    "up_proj",
-    "down_proj",
-}
+class QuantizedLayerFactory:
+    def __init__(self, quantization_config: RBLNQuantizationConfig):
+        self.quantization_config = quantization_config
+
+    def create_linear(self, layer: Linear) -> Linear:
+        if self.quantization_config.weights in ["int4", "int8"]:
+            return self.create_qlinear(layer)
+        elif self.quantization_config.weights == "fp8":
+            return self.create_fp8linear(layer)
+        else:
+            raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")
+
+    def create_qlinear(self, layer: Linear) -> Linear:
+        return create_qlinear(layer, self.quantization_config)
+
+    def create_fp8linear(self, layer: Linear) -> Linear:
+        return create_fp8linear(layer, self.quantization_config)
 
 
 def prepare_model_for_quantization(
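
For orientation, a minimal usage sketch of the extended config follows (not part of the diff); the import path is inferred from the changed-files list above, and the chosen values are illustrative only:

# Minimal sketch: constructing the extended quantization config.
from optimum.rbln.transformers.utils.rbln_quantization import RBLNQuantizationConfig

# fp8 weights and activations with an fp8 KV cache; "rbln" is the only supported format.
quant_config = RBLNQuantizationConfig(
    format="rbln",
    weights="fp8",
    activations="fp8",
    kv_caches="fp8",
)
print(quant_config._prepare_for_serialization())
# {'format': 'rbln', 'weights': 'fp8', 'activations': 'fp8', 'kv_caches': 'fp8'}
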
@@ -111,64 +147,51 @@ def prepare_model_for_quantization(
     cache_dir: Optional[str] = None,
     force_download: bool = False,
     local_files_only: bool = False,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
 ) -> torch.nn.Module:
     """
     Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
     """
-    update_layers_to_quantize(model)
-    load_weights(
-        model,
+
+    # 1. Load weight files
+    safetensor_files = load_weight_files(
         model_id,
-        n_layer,
         use_auth_token=use_auth_token,
         revision=revision,
         cache_dir=cache_dir,
         force_download=force_download,
         local_files_only=local_files_only,
     )
-    return model
-
 
-def update_layers_to_quantize(module: torch.nn.Module) -> None:
-    """
-    Updates specified linear layers to quantized (qlinear) layers in the given module.
-    """
-
-    logger.debug("Updating layers to be quantized") # TODO(jongho): remove.
-    processed_layers = []
+    # 2. Update linear layers based on the quantization config
+    update_layers_to_quantize(model, rbln_quantization)
 
-    for name, layer in module.named_modules():
-        if is_target_for_qlinear_replacement(name, layer):
-            parent_module, layer_name = get_parent_and_child(module, name)
-            setattr(parent_module, layer_name, create_qlinear(layer))
-            processed_layers.append(name)
+    # 3. Load weights into model parameters
+    load_weights_from_files(
+        model,
+        safetensor_files,
+        n_layer,
+        rbln_quantization=rbln_quantization,
+    )
 
-    if processed_layers:
-        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+    return model
 
 
-def load_weights(
-    model,
-    model_id,
-    n_layer=None,
-    use_auth_token=None,
-    revision=None,
-    cache_dir=None,
-    force_download=False,
-    local_files_only=False,
-):
+def load_weight_files(
+    model_id: str,
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    force_download: bool = False,
+    local_files_only: bool = False,
+) -> list[str]:
     """
-    Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+    Discover and download safetensors files for the given model id.
     """
 
-    model_params = dict(model.named_parameters(recurse=True))
-    model_buffers = dict(model.named_buffers(recurse=True))
-
     if os.path.isdir(model_id):
         safetensor_files = glob.glob(f"{model_id}/*.safetensors")
     else:
-        from huggingface_hub import hf_hub_download, list_repo_files
-
         try:
             # List all files in the repository
             repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
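
The refactor above splits the old load_weights helper into an explicit three-step flow: discover the safetensors files, swap the target linear layers, then copy weights. A rough end-to-end sketch follows; it assumes the leading positional parameters of prepare_model_for_quantization remain model and model_id (only the trailing parameters are visible in this hunk), and the checkpoint path is a placeholder:

# Hypothetical call (not from the diff); "/path/to/ckpt" stands for a local directory
# holding *.safetensors files plus the usual config.json.
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.rbln.transformers.utils.rbln_quantization import (
    RBLNQuantizationConfig,
    prepare_model_for_quantization,
)

config = AutoConfig.from_pretrained("/path/to/ckpt")
model = AutoModelForCausalLM.from_config(config)

quant_config = RBLNQuantizationConfig(weights="int4", activations="fp16")
model = prepare_model_for_quantization(  # assumed (model, model_id, ...) ordering
    model,
    "/path/to/ckpt",  # local dir: *.safetensors files are globbed by load_weight_files
    rbln_quantization=quant_config,
)
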
@@ -195,27 +218,238 @@ def load_weights(
     if not safetensor_files:
         raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")
 
+    return safetensor_files
+
+
+def update_layers_to_quantize(
+    module: torch.nn.Module,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+) -> None:
+    """
+    Updates specified linear layers to quantized (qlinear) layers in the given module.
+    """
+
+    processed_layers = []
+    quantized_layer_factory = QuantizedLayerFactory(rbln_quantization)
+
+    for name, layer in module.named_modules():
+        if is_target_for_qlinear_replacement(name, layer):
+            parent_module, layer_name = get_parent_and_child(module, name)
+            setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
+            processed_layers.append(name)
+
+    if processed_layers:
+        logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+def _last_segment(key: str) -> str:
+    parts = key.split(".")
+    return parts[-1]
+
+
+def _replace_last_with(key: str, new_tail: str) -> str:
+    parts = key.split(".")
+    return ".".join(parts[:-1] + new_tail.split("."))
+
+
+def _matches_any_alias(key: str, kind: str) -> bool:
+    tail = _last_segment(key)
+    return tail in VARIANT_ALIASES.get(kind, [])
+
+
+def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
+    if t.ndim == 0:
+        return t
+    return t.reshape(-1).amax()
+
+
+def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
+    s = scale
+    if s.ndim == 0:
+        # scalar -> expand to [out_features, 1]
+        return s.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 1:
+        if s.numel() == 1:
+            return s.reshape(1, 1).expand(out_features, 1).contiguous()
+        if s.numel() == out_features:
+            return s.reshape(out_features, 1).contiguous()
+        # fallback: reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    if s.ndim == 2:
+        if s.shape == (out_features, 1):
+            return s.contiguous()
+        if s.shape == (1, out_features):
+            return s.transpose(0, 1).contiguous()
+        # fallback: reduce to [out_features] on non-out dims if possible
+        if s.shape[0] == out_features:
+            v = s
+            while v.ndim > 2:
+                v = v.amax(dim=-1)
+            if v.shape[-1] != 1:
+                v = v.amax(dim=-1, keepdim=True)
+            return v.contiguous()
+        # otherwise reduce to scalar then expand
+        v = _reduce_to_scalar(s)
+        return v.reshape(1, 1).expand(out_features, 1).contiguous()
+    # high-rank: reduce to scalar then expand
+    v = _reduce_to_scalar(s)
+    return v.reshape(1, 1).expand(out_features, 1).contiguous()
+
+
+def _kv_split_items(base_key: str, tensor: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
+    # base_key is the original key whose last token was 'kv_scale'
+    # We produce keys with 'k_proj.k_scale' and 'v_proj.v_scale'
+    if tensor.ndim == 1 and tensor.numel() >= 2:
+        tk, tv = tensor[0], tensor[1]
+    elif tensor.ndim == 2 and tensor.shape[0] >= 2 and tensor.shape[1] == 1:
+        tk, tv = tensor[0, 0], tensor[1, 0]
+    else:
+        tk = tv = tensor
+    k_key = _replace_last_with(base_key, "k_proj.k_scale")
+    v_key = _replace_last_with(base_key, "v_proj.v_scale")
+    return [(k_key, tk), (v_key, tv)]
+
+
+def canonicalize_checkpoint_items(
+    model: torch.nn.Module,
+    items: Iterable[Tuple[str, torch.Tensor]],
+    rbln_quantization: Optional[RBLNQuantizationConfig],
+) -> List[Tuple[str, torch.Tensor]]:
+    params = dict(model.named_parameters(recurse=True))
+    results: List[Tuple[str, torch.Tensor]] = []
+
+    for key, value in items:
+        t = value
+        # Normalize weight scale variants
+        if _matches_any_alias(key, "weight_scale"):
+            # rename last token to the canonical weight scale key
+            target_key = _replace_last_with(key, "weight_scale")
+
+            # Determine associated weight param to infer shape
+            weight_key = _replace_last_with(target_key, "weight")
+            out_features = None
+            if weight_key in params:
+                wshape = params[weight_key].shape
+                if len(wshape) == 2:
+                    out_features = int(wshape[0])
+
+            if rbln_quantization.weights in ["int4", "int8"] and out_features is not None:
+                t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
+            elif rbln_quantization.weights == "fp8":
+                # Use a conservative scalar scale to ensure broadcastability
+                t = _reduce_to_scalar(t.to(torch.float32))
+            else:
+                t = t.to(torch.float32)
+
+            results.append((target_key, t))
+            continue
+
+        # Normalize input/activation scale variants
+        if _matches_any_alias(key, "input_scale"):
+            target_key = _replace_last_with(key, "input_scale")
+            t = _reduce_to_scalar(t.to(torch.float32))
+            results.append((target_key, t))
+            continue
+
+        # KV scale handling
+        if _matches_any_alias(key, "kv_scale"):
+            # For quark-like formats, expand to k/v
+            kv_items = _kv_split_items(key, t.to(torch.float32))
+            for k2, v2 in kv_items:
+                results.append((k2, v2))
+            continue
+
+        if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
+            results.append((key, t.to(torch.float32)))
+            continue
+
+        # Default: passthrough
+        results.append((key, t))
+
+    return results
+
+
+def load_weights_from_files(
+    model: torch.nn.Module,
+    safetensor_files: list[str],
+    n_layer: Optional[int] = None,
+    rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+):
+    """
+    Load safetensor file data directly into the model from provided safetensor files,
+    filtering by layer if n_layer is provided.
+    """
+
+    model_params = dict(model.named_parameters(recurse=True))
+    model_buffers = dict(model.named_buffers(recurse=True))
+
     target_layers = list(range(n_layer)) if n_layer is not None else None
 
     unloaded_keys = []
+    loaded_input_scale = False
+    loaded_kv_scale = False
+    loaded_weight_scale = False
+
     for safetensor_file in safetensor_files:
         file_data = load_file(safetensor_file)
-        for key, value in file_data.items():
+
+        # Normalize all (key, tensor) pairs to the internal schema
+        normalized_items = canonicalize_checkpoint_items(
+            model=model,
+            items=file_data.items(),
+            rbln_quantization=rbln_quantization,
+        )
+
+        for key, value in normalized_items:
+            # Track which types of scales were observed (post-normalization)
+            if key.endswith("input_scale"):
+                loaded_input_scale = True
+            if key.endswith("weight_scale"):
+                loaded_weight_scale = True
+            if key.endswith("k_scale") or key.endswith("v_scale"):
+                loaded_kv_scale = True
+
+            # Filter by layer index if requested
             if target_layers is not None:
                 parts = key.split(".")
-
                 if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
                     continue
 
+            # Copy into parameters or buffers
             if key in model_params:
+                # Ensure dtype compatibility
+                if model_params[key].dtype != value.dtype:
+                    value = value.to(model_params[key].dtype)
                 model_params[key].data.copy_(value)
             elif key in model_buffers:
+                if model_buffers[key].dtype != value.dtype:
+                    value = value.to(model_buffers[key].dtype)
                 model_buffers[key].data.copy_(value)
             else:
                 unloaded_keys.append(key)
 
     if len(unloaded_keys) > 0:
         logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
+    if not loaded_input_scale and rbln_quantization.activations == "fp8":
+        raise ValueError(
+            "No input_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_weight_scale and rbln_quantization.weights == "fp8":
+        raise ValueError(
+            "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
+        raise ValueError(
+            "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
+            "If you are using fp8 quantization, you need to use the correct quantization config."
+        )
+    if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
+        logger.warning(
+            "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
+        )
 
 
 def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
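
To make the alias handling above concrete, here is a small illustration of the key rewriting performed by canonicalize_checkpoint_items through the module-private helpers (importing them directly is for demonstration only, not an advertised API):

# Illustration only: how checkpoint keys are normalized by the helpers above.
import torch

from optimum.rbln.transformers.utils.rbln_quantization import (
    _kv_split_items,
    _replace_last_with,
)

# An alias such as "scales" is renamed to the canonical "weight_scale" key.
print(_replace_last_with("model.layers.0.self_attn.q_proj.scales", "weight_scale"))
# -> model.layers.0.self_attn.q_proj.weight_scale

# A fused "kv_scale" tensor is split into per-projection k/v scale entries.
print(_kv_split_items("model.layers.0.self_attn.kv_scale", torch.tensor([0.5, 0.25])))
# -> [('...self_attn.k_proj.k_scale', tensor(0.5000)), ('...self_attn.v_proj.v_scale', tensor(0.2500))]
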
@@ -225,6 +459,10 @@ def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
     return layer_name.split(".")[-1] in QUANTIZED_WEIGHTS and isinstance(layer, torch.nn.Linear)
 
 
+def is_target_for_adding_kv_scales(layer_name: str) -> bool:
+    return layer_name.split(".")[-1] in ["self_attn"]
+
+
 def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
     """
     Splits the full layer name to retrieve the parent module and the child layer.
@@ -243,22 +481,84 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
     return obj
 
 
-def create_qlinear(layer: Linear) -> Linear:
+def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
     """
     Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
     """
 
     def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        if inputs.dtype != self.scales.dtype:
-            raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
+        weight_scale = self.weight_scale
+        if inputs.dtype != weight_scale.dtype:
+            raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}")
 
         w_fp = self.weight.type(inputs.dtype)
-        w_fp *= self.scales.view(-1, 1)
+        w_fp *= weight_scale.view(-1, 1)
         return F.linear(inputs, w_fp, self.bias)
 
     # Convert weight to int8 and add scale parameter
     layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
-    layer.scales = Parameter(torch.ones(layer.out_features, dtype=torch.float32), requires_grad=False)
+    layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
     layer.forward = lambda inputs: qlinear_forward(layer, inputs)
 
     return layer
+
+
+def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
+    """
+    Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
+    """
+
+    def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+        return qweight
+
+    def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
+        A = A.type(out_dtype)
+        B = B.type(out_dtype)
+
+        if A_scale is not None:
+            A *= A_scale
+        if B_scale is not None:
+            B *= B_scale.to(out_dtype)
+
+        output = torch.nn.functional.linear(A, B, bias=bias)
+        return output
+
+    def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.input_scale:
+            input = static_per_tensor_quantize(x, self.input_scale)
+        else:
+            input = x
+
+        if self.weight_scale:
+            # broadcast weight_scale to vector
+            weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
+        else:
+            weight_scale = None
+        output = fp8_gemm(
+            A=input,
+            A_scale=self.input_scale,
+            B=self.weight,
+            B_scale=weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        return output
+
+    layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
+    layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    if rbln_quantization.activations == "fp8":
+        layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+    else:
+        layer.input_scale = None
+
+    if rbln_quantization.kv_caches == "fp8":
+        layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+        layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+    layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
+
+    return layer
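
The static per-tensor fp8 scheme used by create_fp8linear divides by a scale, clamps to the float8_e4m3fn range, and rescales at matmul time. A standalone round-trip sketch (illustrative only; requires a PyTorch build with float8_e4m3fn support):

# Standalone illustration of static per-tensor fp8 quantization/dequantization.
import torch

finfo = torch.finfo(torch.float8_e4m3fn)

x = torch.randn(4, 8)
scale = x.abs().max() / finfo.max      # per-tensor scale
q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
x_hat = q.to(torch.float32) * scale    # dequantize, as fp8_gemm does before F.linear
print((x - x_hat).abs().max())         # small quantization error
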
optimum/rbln/utils/runtime_utils.py
@@ -14,7 +14,7 @@
 
 import re
 import threading
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import rebel
 import torch
@@ -94,7 +94,7 @@ class RBLNPytorchRuntime:
     def __call__(self, *args: Any, **kwds: Any) -> Any:
         return self.forward(*args, **kwds)
 
-    def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+    def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
         # filtering useless args or kwarg such as None.
         args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
         kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
@@ -142,7 +142,7 @@ class UnavailableRuntime:
         """Returns an iterator with self as the only item."""
         return iter([self])
 
-    def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+    def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
         """Raises a detailed RuntimeError explaining why inference cannot be performed."""
         raise RuntimeError(
             "Cannot perform inference: RBLN runtime is not available.\n\n"
optimum/rbln/utils/submodule.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
 
 from transformers import PretrainedConfig
 
@@ -22,7 +22,7 @@ from ..utils.model_utils import get_rbln_model_cls
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
     from ..modeling import RBLNModel
 
@@ -42,7 +42,12 @@ class SubModulesMixin:
             setattr(self, submodule_meta["name"], submodule)
 
     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         return rbln_config
 
     @classmethod
@@ -51,6 +56,7 @@ class SubModulesMixin:
     ) -> List["RBLNModel"]:
         rbln_submodules = []
         submodule_prefix = getattr(cls, "_rbln_submodule_prefix", None)
+        preprocessors = kwargs.pop("preprocessors", [])
 
         for submodule in cls._rbln_submodules:
             submodule_name = submodule["name"]
@@ -69,7 +75,7 @@ class SubModulesMixin:
                 submodule_rbln_config = submodule_rbln_config_class(**submodule_rbln_config)
                 setattr(rbln_config, submodule_name, submodule_rbln_config)
 
-            submodule_rbln_config = submodule_cls._update_submodule_config(model, submodule_rbln_config)
+            submodule_rbln_config = submodule_cls._update_submodule_config(model, submodule_rbln_config, preprocessors)
 
             rbln_submodule = submodule_cls.from_model(
                 model=torch_submodule,
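
The submodule changes above thread the loaded preprocessors through to _update_submodule_config. A hypothetical override showing how a submodule class could use the new argument (the class and attribute names below are placeholders, not from the diff):

# Hypothetical override; in practice this would be an RBLNModel subclass.
class RBLNMyVisionSubmodule:
    @classmethod
    def _update_submodule_config(cls, model, rbln_config, preprocessors):
        # e.g. derive an image size from an image processor when the user
        # did not set it explicitly on the submodule config.
        for preprocessor in preprocessors or []:
            size = getattr(preprocessor, "size", None)
            if isinstance(size, dict) and getattr(rbln_config, "image_size", None) is None:
                rbln_config.image_size = size.get("height", size.get("shortest_edge"))
        return rbln_config
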
{optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.8.2a7
+Version: 0.8.3
 Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai