optimum-rbln 0.8.3a0__py3-none-any.whl → 0.8.3a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

optimum/rbln/__init__.py CHANGED
@@ -169,6 +169,9 @@ _import_structure = {
169
169
  "RBLNAutoencoderKLConfig",
170
170
  "RBLNAutoencoderKLCosmos",
171
171
  "RBLNAutoencoderKLCosmosConfig",
172
+ "RBLNAutoPipelineForImage2Image",
173
+ "RBLNAutoPipelineForInpainting",
174
+ "RBLNAutoPipelineForText2Image",
172
175
  "RBLNControlNetModel",
173
176
  "RBLNControlNetModelConfig",
174
177
  "RBLNCosmosTextToWorldPipeline",
@@ -238,6 +241,9 @@ if TYPE_CHECKING:
238
241
  RBLNAutoencoderKLConfig,
239
242
  RBLNAutoencoderKLCosmos,
240
243
  RBLNAutoencoderKLCosmosConfig,
244
+ RBLNAutoPipelineForImage2Image,
245
+ RBLNAutoPipelineForInpainting,
246
+ RBLNAutoPipelineForText2Image,
241
247
  RBLNControlNetModel,
242
248
  RBLNControlNetModelConfig,
243
249
  RBLNCosmosSafetyChecker,
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.8.3a0'
32
- __version_tuple__ = version_tuple = (0, 8, 3, 'a0')
31
+ __version__ = version = '0.8.3a2'
32
+ __version_tuple__ = version_tuple = (0, 8, 3, 'a2')
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -59,6 +59,9 @@ _import_structure = {
59
59
  "RBLNVQModelConfig",
60
60
  ],
61
61
  "pipelines": [
62
+ "RBLNAutoPipelineForImage2Image",
63
+ "RBLNAutoPipelineForInpainting",
64
+ "RBLNAutoPipelineForText2Image",
62
65
  "RBLNCosmosTextToWorldPipeline",
63
66
  "RBLNCosmosVideoToWorldPipeline",
64
67
  "RBLNCosmosSafetyChecker",
@@ -144,6 +147,9 @@ if TYPE_CHECKING:
144
147
  RBLNVQModel,
145
148
  )
146
149
  from .pipelines import (
150
+ RBLNAutoPipelineForImage2Image,
151
+ RBLNAutoPipelineForInpainting,
152
+ RBLNAutoPipelineForText2Image,
147
153
  RBLNCosmosSafetyChecker,
148
154
  RBLNCosmosTextToWorldPipeline,
149
155
  RBLNCosmosVideoToWorldPipeline,
@@ -18,6 +18,11 @@ from transformers.utils import _LazyModule
18
18
 
19
19
 
20
20
  _import_structure = {
21
+ "auto_pipeline": [
22
+ "RBLNAutoPipelineForImage2Image",
23
+ "RBLNAutoPipelineForInpainting",
24
+ "RBLNAutoPipelineForText2Image",
25
+ ],
21
26
  "controlnet": [
22
27
  "RBLNMultiControlNetModel",
23
28
  "RBLNStableDiffusionControlNetImg2ImgPipeline",
@@ -56,6 +61,11 @@ _import_structure = {
56
61
  ],
57
62
  }
58
63
  if TYPE_CHECKING:
64
+ from .auto_pipeline import (
65
+ RBLNAutoPipelineForImage2Image,
66
+ RBLNAutoPipelineForInpainting,
67
+ RBLNAutoPipelineForText2Image,
68
+ )
59
69
  from .controlnet import (
60
70
  RBLNMultiControlNetModel,
61
71
  RBLNStableDiffusionControlNetImg2ImgPipeline,
@@ -0,0 +1,237 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import importlib
17
+ from typing import Type
18
+
19
+ from diffusers.models.controlnets import ControlNetUnionModel
20
+ from diffusers.pipelines.auto_pipeline import (
21
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
22
+ AUTO_INPAINT_PIPELINES_MAPPING,
23
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
24
+ AutoPipelineForImage2Image,
25
+ AutoPipelineForInpainting,
26
+ AutoPipelineForText2Image,
27
+ _get_task_class,
28
+ )
29
+ from huggingface_hub.utils import validate_hf_hub_args
30
+
31
+ from optimum.rbln.modeling_base import RBLNBaseModel
32
+ from optimum.rbln.utils.model_utils import (
33
+ MODEL_MAPPING,
34
+ convert_hf_to_rbln_model_name,
35
+ convert_rbln_to_hf_model_name,
36
+ get_rbln_model_cls,
37
+ )
38
+
39
+
40
+ class RBLNAutoPipelineBase:
41
+ _model_mapping = None
42
+ _model_mapping_names = None
43
+
44
+ @classmethod
45
+ def get_rbln_cls(cls, pretrained_model_name_or_path, export=True, **kwargs):
46
+ if export:
47
+ hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
48
+ rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
49
+ else:
50
+ rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
51
+ if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names:
52
+ raise ValueError(
53
+ f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
54
+ "Please use the `from_pretrained()` method of the appropriate class to load this model, "
55
+ f"or directly use '{rbln_class_name}.from_pretrained()`."
56
+ )
57
+
58
+ try:
59
+ rbln_cls = get_rbln_model_cls(rbln_class_name)
60
+ except AttributeError as e:
61
+ raise AttributeError(
62
+ f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
63
+ "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
64
+ ) from e
65
+
66
+ return rbln_cls
67
+
68
+ @classmethod
69
+ def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
70
+ """
71
+ Retrieve the class name stored in the `_class_name` field of the model index config.
72
+
73
+ Args:
74
+ pretrained_model_name_or_path (str): Identifier of the model.
75
+
76
+ Returns:
77
+ str: The value of the `_class_name` field of the model index config.
78
+ """
79
+ model_index_config = cls.load_config(pretrained_model_name_or_path)
80
+
81
+ if "_class_name" not in model_index_config:
82
+ raise ValueError(
83
+ "The `_class_name` field is missing from model_index_config. This is unexpected and should be reported as an issue. "
84
+ "Please use the `from_pretrained()` method of the appropriate class to load this model."
85
+ )
86
+
87
+ return model_index_config["_class_name"]
88
+
89
+ @classmethod
90
+ def infer_hf_model_class(
91
+ cls,
92
+ pretrained_model_or_path,
93
+ cache_dir=None,
94
+ force_download=False,
95
+ proxies=None,
96
+ token=None,
97
+ local_files_only=False,
98
+ revision=None,
99
+ **kwargs,
100
+ ):
101
+ config = cls.load_config(
102
+ pretrained_model_or_path,
103
+ cache_dir=cache_dir,
104
+ force_download=force_download,
105
+ proxies=proxies,
106
+ token=token,
107
+ local_files_only=local_files_only,
108
+ revision=revision,
109
+ )
110
+ pipeline_key_name = cls.get_pipeline_key_name(config, **kwargs)
111
+
112
+ pipeline_cls = _get_task_class(cls._model_mapping, pipeline_key_name)
113
+
114
+ return pipeline_cls
115
+
116
+ @classmethod
117
+ def get_pipeline_key_name(cls, config, **kwargs):
118
+ orig_class_name = config["_class_name"]
119
+ if "ControlPipeline" in orig_class_name:
120
+ to_replace = "ControlPipeline"
121
+ else:
122
+ to_replace = "Pipeline"
123
+
124
+ if "controlnet" in kwargs:
125
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
126
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
127
+ else:
128
+ orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
129
+ if "enable_pag" in kwargs:
130
+ enable_pag = kwargs.pop("enable_pag")
131
+ if enable_pag:
132
+ orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
133
+
134
+ return orig_class_name
135
+
136
+ @classmethod
137
+ @validate_hf_hub_args
138
+ def from_pretrained(cls, model_id, **kwargs):
139
+ rbln_cls = cls.get_rbln_cls(model_id, **kwargs)
140
+ return rbln_cls.from_pretrained(model_id, **kwargs)
141
+
142
+ @classmethod
143
+ def from_model(cls, model, **kwargs):
144
+ rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
145
+ return rbln_cls.from_model(model, **kwargs)
146
+
147
+ @staticmethod
148
+ def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
149
+ """
150
+ Register a new RBLN model class.
151
+
152
+ Args:
153
+ rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
154
+ exist_ok (bool): Whether to allow registering an already registered model.
155
+ """
156
+ if not issubclass(rbln_cls, RBLNBaseModel):
157
+ raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")
158
+
159
+ native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
160
+ if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
161
+ if not exist_ok:
162
+ raise ValueError(f"Model for {rbln_cls.__name__} already registered.")
163
+
164
+ MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
165
+
166
+
167
+ class RBLNAutoPipelineForText2Image(RBLNAutoPipelineBase, AutoPipelineForText2Image):
168
+ _model_mapping = AUTO_TEXT2IMAGE_PIPELINES_MAPPING
169
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()}
170
+
171
+
172
+ class RBLNAutoPipelineForImage2Image(RBLNAutoPipelineBase, AutoPipelineForImage2Image):
173
+ _model_mapping = AUTO_IMAGE2IMAGE_PIPELINES_MAPPING
174
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()}
175
+
176
+ @classmethod
177
+ def get_pipeline_key_name(cls, config, **kwargs):
178
+ orig_class_name = config["_class_name"]
179
+ # the `orig_class_name` can be:
180
+ # - `*Pipeline` (for regular text-to-image checkpoint)
181
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
182
+ # - `*Img2ImgPipeline` (for refiner checkpoint)
183
+ if "Img2Img" in orig_class_name:
184
+ to_replace = "Img2ImgPipeline"
185
+ elif "ControlPipeline" in orig_class_name:
186
+ to_replace = "ControlPipeline"
187
+ else:
188
+ to_replace = "Pipeline"
189
+
190
+ if "controlnet" in kwargs:
191
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
192
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
193
+ else:
194
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
195
+ if "enable_pag" in kwargs:
196
+ enable_pag = kwargs.pop("enable_pag")
197
+ if enable_pag:
198
+ orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
199
+
200
+ if to_replace == "ControlPipeline":
201
+ orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
202
+
203
+ return orig_class_name
204
+
205
+
206
+ class RBLNAutoPipelineForInpainting(RBLNAutoPipelineBase, AutoPipelineForInpainting):
207
+ _model_mapping = AUTO_INPAINT_PIPELINES_MAPPING
208
+ _model_mapping_names = {x[0]: x[1].__name__ for x in AUTO_INPAINT_PIPELINES_MAPPING.items()}
209
+
210
+ @classmethod
211
+ def get_pipeline_key_name(cls, config, **kwargs):
212
+ orig_class_name = config["_class_name"]
213
+
214
+ # The `orig_class_name` can be:
215
+ # - `*InpaintPipeline` (for inpaint-specific checkpoint)
216
+ # - `*ControlPipeline` (for Flux tools specific checkpoint)
217
+ # - or `*Pipeline` (for regular text-to-image checkpoint)
218
+ if "Inpaint" in orig_class_name:
219
+ to_replace = "InpaintPipeline"
220
+ elif "ControlPipeline" in orig_class_name:
221
+ to_replace = "ControlPipeline"
222
+ else:
223
+ to_replace = "Pipeline"
224
+
225
+ if "controlnet" in kwargs:
226
+ if isinstance(kwargs["controlnet"], ControlNetUnionModel):
227
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
228
+ else:
229
+ orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
230
+ if "enable_pag" in kwargs:
231
+ enable_pag = kwargs.pop("enable_pag")
232
+ if enable_pag:
233
+ orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
234
+ if to_replace == "ControlPipeline":
235
+ orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
236
+
237
+ return orig_class_name
@@ -126,7 +126,14 @@ class T5Decoder(Seq2SeqDecoder):
126
126
  b_size = attention_mask.shape[0]
127
127
  batch_decoder_position_bias = []
128
128
  for i in range(b_size):
129
- batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
129
+ if torch.compiler.is_exporting():
130
+ cache_pos = cache_position[i][0].item()
131
+ torch._check_is_size(cache_pos)
132
+ torch._check(cache_pos >= 0)
133
+ torch._check(cache_pos < self._dec_position_bias.shape[2])
134
+ else:
135
+ cache_pos = cache_position[i][0]
136
+ batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
130
137
  batch_decoder_position_bias.append(batch_position_bias)
131
138
  position_bias = torch.cat(batch_decoder_position_bias, dim=0)
132
139
 
@@ -13,9 +13,8 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import glob
16
- import json
17
16
  import os
18
- from typing import Any, Dict, Optional, Union
17
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
19
18
 
20
19
  import torch
21
20
  from huggingface_hub import hf_hub_download, list_repo_files
@@ -30,10 +29,31 @@ from ...utils.logging import get_logger
30
29
  logger = get_logger()
31
30
 
32
31
 
32
+ # Constants
33
+ QUANTIZED_WEIGHTS = {
34
+ "q_proj",
35
+ "k_proj",
36
+ "v_proj",
37
+ "o_proj",
38
+ "gate_proj",
39
+ "up_proj",
40
+ "down_proj",
41
+ }
42
+
43
+ # Common alias sets seen in community checkpoints
44
+ VARIANT_ALIASES: Dict[str, List[str]] = {
45
+ "weight_scale": ["weight_scale", "scales", "w_scale", "scale"],
46
+ "input_scale": ["input_scale", "act_scale", "activation_scale", "a_scale"],
47
+ "kv_scale": ["kv_scale", "kv_scales"],
48
+ "k_scale": ["k_scale", "k_scales"],
49
+ "v_scale": ["v_scale", "v_scales"],
50
+ }
51
+
52
+
33
53
  class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
34
54
  SUPPORTED_FORMATS = ["rbln"]
35
- SUPPORTED_WEIGHTS = ["int4", "fp8", "fp16"]
36
- SUPPORTED_ACTIVATIONS = ["fp8", "fp16"]
55
+ SUPPORTED_WEIGHTS = ["int4", "int8", "fp8", "fp16"]
56
+ SUPPORTED_ACTIVATIONS = ["int8", "fp8", "fp16"]
37
57
  SUPPORTED_KVCACHES = ["fp8", "fp16"]
38
58
  RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
39
59
 
@@ -64,7 +84,6 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
64
84
  self.weights = weights or "fp16"
65
85
  self.activations = activations or "fp16"
66
86
  self.kv_caches = kv_caches or "fp16"
67
-
68
87
  self._validate()
69
88
 
70
89
  def _validate(self):
@@ -105,7 +124,7 @@ class QuantizedLayerFactory:
105
124
  self.quantization_config = quantization_config
106
125
 
107
126
  def create_linear(self, layer: Linear) -> Linear:
108
- if self.quantization_config.weights == "int4":
127
+ if self.quantization_config.weights in ["int4", "int8"]:
109
128
  return self.create_qlinear(layer)
110
129
  elif self.quantization_config.weights == "fp8":
111
130
  return self.create_fp8linear(layer)
@@ -119,18 +138,6 @@ class QuantizedLayerFactory:
119
138
  return create_fp8linear(layer, self.quantization_config)
120
139
 
121
140
 
122
- # Constants
123
- QUANTIZED_WEIGHTS = {
124
- "q_proj",
125
- "k_proj",
126
- "v_proj",
127
- "o_proj",
128
- "gate_proj",
129
- "up_proj",
130
- "down_proj",
131
- }
132
-
133
-
134
141
  def prepare_model_for_quantization(
135
142
  model: torch.nn.Module,
136
143
  model_id: str,
@@ -146,8 +153,8 @@ def prepare_model_for_quantization(
146
153
  Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
147
154
  """
148
155
 
149
- # 1. Load weight files and safetensors.index.json
150
- safetensor_files, index_data = load_weight_files_and_index(
156
+ # 1. Load weight files
157
+ safetensor_files = load_weight_files(
151
158
  model_id,
152
159
  use_auth_token=use_auth_token,
153
160
  revision=revision,
@@ -156,43 +163,34 @@ def prepare_model_for_quantization(
156
163
  local_files_only=local_files_only,
157
164
  )
158
165
 
159
- # 2. Determine format from safetensors.index.json
160
- determined_format = determine_format_from_index(index_data)
161
-
162
- # 3. Update linear layers based on the determined format
166
+ # 2. Update linear layers based on the quantization config
163
167
  update_layers_to_quantize(model, rbln_quantization)
164
168
 
165
- # 4. Load weights into model parameters
169
+ # 3. Load weights into model parameters
166
170
  load_weights_from_files(
167
171
  model,
168
172
  safetensor_files,
169
173
  n_layer,
170
174
  rbln_quantization=rbln_quantization,
171
- determined_format=determined_format,
172
175
  )
173
176
 
174
177
  return model
175
178
 
176
179
 
177
- def load_weight_files_and_index(
180
+ def load_weight_files(
178
181
  model_id: str,
179
182
  use_auth_token: Optional[Union[bool, str]] = None,
180
183
  revision: Optional[str] = None,
181
184
  cache_dir: Optional[str] = None,
182
185
  force_download: bool = False,
183
186
  local_files_only: bool = False,
184
- ) -> tuple[list[str], Optional[Dict]]:
187
+ ) -> list[str]:
185
188
  """
186
- Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
189
+ Discover and download safetensors files for the given model id.
187
190
  """
188
- index_data = None
189
191
 
190
192
  if os.path.isdir(model_id):
191
193
  safetensor_files = glob.glob(f"{model_id}/*.safetensors")
192
- index_path = os.path.join(model_id, "model.safetensors.index.json")
193
- if os.path.exists(index_path):
194
- with open(index_path, "r") as f:
195
- index_data = json.load(f)
196
194
  else:
197
195
  try:
198
196
  # List all files in the repository
@@ -213,20 +211,6 @@ def load_weight_files_and_index(
213
211
  local_files_only=local_files_only,
214
212
  )
215
213
  safetensor_files.append(downloaded_file)
216
- elif file == "model.safetensors.index.json":
217
- # Download the index file
218
- index_file = hf_hub_download(
219
- repo_id=model_id,
220
- filename=file,
221
- revision=revision,
222
- token=use_auth_token,
223
- cache_dir=cache_dir,
224
- force_download=force_download,
225
- local_files_only=local_files_only,
226
- )
227
-
228
- with open(index_file, "r") as f:
229
- index_data = json.load(f)
230
214
  except Exception as e:
231
215
  logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
232
216
  raise e
@@ -234,32 +218,7 @@ def load_weight_files_and_index(
234
218
  if not safetensor_files:
235
219
  raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")
236
220
 
237
- return safetensor_files, index_data
238
-
239
-
240
- def determine_format_from_index(index_data: Optional[Dict]) -> str:
241
- """
242
- Determine the quantization format from safetensors.index.json data.
243
-
244
- Args:
245
- index_data: The loaded safetensors.index.json content
246
-
247
- Returns:
248
- str: The determined format string
249
- """
250
- if index_data is None:
251
- raise ValueError("safetensors.index.json not found")
252
- if "weight_map" not in index_data:
253
- raise ValueError("weight_map not found in safetensors.index.json")
254
-
255
- if any("self_attn.k_proj.k_scale" in key for key in index_data["weight_map"]):
256
- return "tensorrt"
257
- elif any("self_attn.kv_scale" in key for key in index_data["weight_map"]):
258
- return "quark"
259
- elif any("weight_scale" in key or "input_scale" in key for key in index_data["weight_map"]):
260
- return "default"
261
- else:
262
- raise ValueError("Unknown quantization format of the index data of weight map.")
221
+ return safetensor_files
263
222
 
264
223
 
265
224
  def update_layers_to_quantize(
@@ -283,12 +242,139 @@ def update_layers_to_quantize(
283
242
  logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
284
243
 
285
244
 
245
+ def _last_segment(key: str) -> str:
246
+ parts = key.split(".")
247
+ return parts[-1]
248
+
249
+
250
+ def _replace_last_with(key: str, new_tail: str) -> str:
251
+ parts = key.split(".")
252
+ return ".".join(parts[:-1] + new_tail.split("."))
253
+
254
+
255
+ def _matches_any_alias(key: str, kind: str) -> bool:
256
+ tail = _last_segment(key)
257
+ return tail in VARIANT_ALIASES.get(kind, [])
258
+
259
+
260
+ def _reduce_to_scalar(t: torch.Tensor) -> torch.Tensor:
261
+ if t.ndim == 0:
262
+ return t
263
+ return t.reshape(-1).amax()
264
+
265
+
266
+ def _coerce_per_out_channel_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
267
+ s = scale
268
+ if s.ndim == 0:
269
+ # scalar -> expand to [out_features, 1]
270
+ return s.reshape(1, 1).expand(out_features, 1).contiguous()
271
+ if s.ndim == 1:
272
+ if s.numel() == 1:
273
+ return s.reshape(1, 1).expand(out_features, 1).contiguous()
274
+ if s.numel() == out_features:
275
+ return s.reshape(out_features, 1).contiguous()
276
+ # fallback: reduce to scalar then expand
277
+ v = _reduce_to_scalar(s)
278
+ return v.reshape(1, 1).expand(out_features, 1).contiguous()
279
+ if s.ndim == 2:
280
+ if s.shape == (out_features, 1):
281
+ return s.contiguous()
282
+ if s.shape == (1, out_features):
283
+ return s.transpose(0, 1).contiguous()
284
+ # fallback: reduce over the last dim to get [out_features, 1] if possible
285
+ if s.shape[0] == out_features:
286
+ v = s
287
+ while v.ndim > 2:
288
+ v = v.amax(dim=-1)
289
+ if v.shape[-1] != 1:
290
+ v = v.amax(dim=-1, keepdim=True)
291
+ return v.contiguous()
292
+ # otherwise reduce to scalar then expand
293
+ v = _reduce_to_scalar(s)
294
+ return v.reshape(1, 1).expand(out_features, 1).contiguous()
295
+ # high-rank: reduce to scalar then expand
296
+ v = _reduce_to_scalar(s)
297
+ return v.reshape(1, 1).expand(out_features, 1).contiguous()
298
+
299
+
300
+ def _kv_split_items(base_key: str, tensor: torch.Tensor) -> List[Tuple[str, torch.Tensor]]:
301
+ # base_key is the original key whose last token was 'kv_scale'
302
+ # We produce keys with 'k_proj.k_scale' and 'v_proj.v_scale'
303
+ if tensor.ndim == 1 and tensor.numel() >= 2:
304
+ tk, tv = tensor[0], tensor[1]
305
+ elif tensor.ndim == 2 and tensor.shape[0] >= 2 and tensor.shape[1] == 1:
306
+ tk, tv = tensor[0, 0], tensor[1, 0]
307
+ else:
308
+ tk = tv = tensor
309
+ k_key = _replace_last_with(base_key, "k_proj.k_scale")
310
+ v_key = _replace_last_with(base_key, "v_proj.v_scale")
311
+ return [(k_key, tk), (v_key, tv)]
312
+
313
+
314
+ def canonicalize_checkpoint_items(
315
+ model: torch.nn.Module,
316
+ items: Iterable[Tuple[str, torch.Tensor]],
317
+ rbln_quantization: Optional[RBLNQuantizationConfig],
318
+ ) -> List[Tuple[str, torch.Tensor]]:
319
+ params = dict(model.named_parameters(recurse=True))
320
+ results: List[Tuple[str, torch.Tensor]] = []
321
+
322
+ for key, value in items:
323
+ t = value
324
+ # Normalize weight scale variants
325
+ if _matches_any_alias(key, "weight_scale"):
326
+ # rename last token to the canonical weight scale key
327
+ target_key = _replace_last_with(key, "weight_scale")
328
+
329
+ # Determine associated weight param to infer shape
330
+ weight_key = _replace_last_with(target_key, "weight")
331
+ out_features = None
332
+ if weight_key in params:
333
+ wshape = params[weight_key].shape
334
+ if len(wshape) == 2:
335
+ out_features = int(wshape[0])
336
+
337
+ if rbln_quantization.weights in ["int4", "int8"] and out_features is not None:
338
+ t = _coerce_per_out_channel_scale(t.to(torch.float32), out_features)
339
+ elif rbln_quantization.weights == "fp8":
340
+ # Use a conservative scalar scale to ensure broadcastability
341
+ t = _reduce_to_scalar(t.to(torch.float32))
342
+ else:
343
+ t = t.to(torch.float32)
344
+
345
+ results.append((target_key, t))
346
+ continue
347
+
348
+ # Normalize input/activation scale variants
349
+ if _matches_any_alias(key, "input_scale"):
350
+ target_key = _replace_last_with(key, "input_scale")
351
+ t = _reduce_to_scalar(t.to(torch.float32))
352
+ results.append((target_key, t))
353
+ continue
354
+
355
+ # KV scale handling
356
+ if _matches_any_alias(key, "kv_scale"):
357
+ # For quark-like formats, expand to k/v
358
+ kv_items = _kv_split_items(key, t.to(torch.float32))
359
+ for k2, v2 in kv_items:
360
+ results.append((k2, v2))
361
+ continue
362
+
363
+ if _matches_any_alias(key, "k_scale") or _matches_any_alias(key, "v_scale"):
364
+ results.append((key, t.to(torch.float32)))
365
+ continue
366
+
367
+ # Default: passthrough
368
+ results.append((key, t))
369
+
370
+ return results
371
+
372
+
286
373
  def load_weights_from_files(
287
374
  model: torch.nn.Module,
288
375
  safetensor_files: list[str],
289
376
  n_layer: Optional[int] = None,
290
377
  rbln_quantization: Optional[RBLNQuantizationConfig] = None,
291
- determined_format: Optional[str] = None,
292
378
  ):
293
379
  """
294
380
  Load safetensor file data directly into the model from provided safetensor files,
@@ -308,33 +394,43 @@ def load_weights_from_files(
308
394
  for safetensor_file in safetensor_files:
309
395
  file_data = load_file(safetensor_file)
310
396
 
311
- for key, value in file_data.items():
312
- loaded_input_scale = loaded_input_scale or "input_scale" in key
313
- loaded_weight_scale = loaded_weight_scale or "weight_scale" in key
314
- loaded_kv_scale = loaded_kv_scale or any(scale in key for scale in ["kv_scale", "k_scale", "v_scale"])
397
+ # Normalize all (key, tensor) pairs to the internal schema
398
+ normalized_items = canonicalize_checkpoint_items(
399
+ model=model,
400
+ items=file_data.items(),
401
+ rbln_quantization=rbln_quantization,
402
+ )
315
403
 
404
+ for key, value in normalized_items:
405
+ # Track which types of scales were observed (post-normalization)
406
+ if key.endswith("input_scale"):
407
+ loaded_input_scale = True
408
+ if key.endswith("weight_scale"):
409
+ loaded_weight_scale = True
410
+ if key.endswith("k_scale") or key.endswith("v_scale"):
411
+ loaded_kv_scale = True
412
+
413
+ # Filter by layer index if requested
316
414
  if target_layers is not None:
317
415
  parts = key.split(".")
318
-
319
416
  if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
320
417
  continue
321
418
 
419
+ # Copy into parameters or buffers
322
420
  if key in model_params:
421
+ # Ensure dtype compatibility
422
+ if model_params[key].dtype != value.dtype:
423
+ value = value.to(model_params[key].dtype)
323
424
  model_params[key].data.copy_(value)
324
425
  elif key in model_buffers:
426
+ if model_buffers[key].dtype != value.dtype:
427
+ value = value.to(model_buffers[key].dtype)
325
428
  model_buffers[key].data.copy_(value)
326
- elif "kv_scale" in key and determined_format == "quark":
327
- if rbln_quantization.kv_caches == "fp8":
328
- model_params[key.replace("kv_scale", "k_proj.k_scale")].data.copy_(value)
329
- model_params[key.replace("kv_scale", "v_proj.v_scale")].data.copy_(value)
330
- else:
331
- unloaded_keys.append(key)
332
429
  else:
333
430
  unloaded_keys.append(key)
334
431
 
335
432
  if len(unloaded_keys) > 0:
336
433
  logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
337
-
338
434
  if not loaded_input_scale and rbln_quantization.activations == "fp8":
339
435
  raise ValueError(
340
436
  "No input_scale found in the checkpoint. Did you use the correct quantization config? "
@@ -391,16 +487,17 @@ def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) ->
391
487
  """
392
488
 
393
489
  def qlinear_forward(self, inputs: torch.Tensor) -> torch.Tensor:
394
- if inputs.dtype != self.scales.dtype:
395
- raise TypeError(f"Expected input dtype {self.scales.dtype}, but got {inputs.dtype}")
490
+ weight_scale = self.weight_scale
491
+ if inputs.dtype != weight_scale.dtype:
492
+ raise TypeError(f"Expected input dtype {weight_scale.dtype}, but got {inputs.dtype}")
396
493
 
397
494
  w_fp = self.weight.type(inputs.dtype)
398
- w_fp *= self.scales.view(-1, 1)
495
+ w_fp *= weight_scale.view(-1, 1)
399
496
  return F.linear(inputs, w_fp, self.bias)
400
497
 
401
498
  # Convert weight to int8 and add scale parameter
402
499
  layer.weight = Parameter(layer.weight.to(torch.int8), requires_grad=False)
403
- layer.scales = Parameter(torch.ones(layer.out_features, dtype=torch.float32), requires_grad=False)
500
+ layer.weight_scale = Parameter(torch.ones(layer.out_features, 1, dtype=torch.float32), requires_grad=False)
404
501
  layer.forward = lambda inputs: qlinear_forward(layer, inputs)
405
502
 
406
503
  return layer
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: optimum-rbln
3
- Version: 0.8.3a0
3
+ Version: 0.8.3a2
4
4
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
5
5
  Project-URL: Homepage, https://rebellions.ai
6
6
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,9 +1,9 @@
1
- optimum/rbln/__init__.py,sha256=i3WWddZ0okF5dQN3B_2wfM7NTnQ37lwdD2udzjMRGH8,17140
2
- optimum/rbln/__version__.py,sha256=1jEsQwW1wBFWk7T3YA4ed9DazJ1mgUtA1pZZvNgBZpc,712
1
+ optimum/rbln/__init__.py,sha256=YhaBhcyu6BgoJrprUogLGAmiBaHayvg6Tjm6PpfJETw,17382
2
+ optimum/rbln/__version__.py,sha256=LoGi14U0L2os-fSHKgBIGeByegJLodfXKteGMBVsCEc,712
3
3
  optimum/rbln/configuration_utils.py,sha256=xneqnRWSUVROqpzbTrBACex42-L9zwo3eSjfHjFuhv4,33072
4
4
  optimum/rbln/modeling.py,sha256=0CMQnpVvW9evNrTFHM2XFbNpRY1HkbFzYJ5sRyYFq0o,14293
5
5
  optimum/rbln/modeling_base.py,sha256=gHfqIO6lKT8smkUthUuRHnbITpxHpnDeBPT8iTeasCk,24575
6
- optimum/rbln/diffusers/__init__.py,sha256=cvyJaFRU1sP1WeRjWrxMOm-5vr0c4X-TD8eqQ21XIgc,6990
6
+ optimum/rbln/diffusers/__init__.py,sha256=1tgU_xWA42BmInqu9bBz_5R_E9TGhhK3mI06YlaiTLg,7232
7
7
  optimum/rbln/diffusers/modeling_diffusers.py,sha256=TAuMb7PSMjNwK7mh5ItE_CtAEgYeZKI27XkFFmxjHlQ,19902
8
8
  optimum/rbln/diffusers/configurations/__init__.py,sha256=vMRnPY4s-Uju43xP038D2EA18X_mhy2YfsZVpSU-VoA,1322
9
9
  optimum/rbln/diffusers/configurations/models/__init__.py,sha256=7q95gtgDzCeIBogGw8SLQoHT4Wch7vpLJVF2UQovuoo,567
@@ -35,7 +35,8 @@ optimum/rbln/diffusers/models/transformers/transformer_cosmos.py,sha256=UQ_R7RVJ
35
35
  optimum/rbln/diffusers/models/transformers/transformer_sd3.py,sha256=yF7sS0QvawowpV9hR5GeT8DaE8CCp3mj1njHHd9cKTc,6630
36
36
  optimum/rbln/diffusers/models/unets/__init__.py,sha256=MaICuK9CWjgzejXy8y2NDrphuEq1rkzanF8u45k6O5I,655
37
37
  optimum/rbln/diffusers/models/unets/unet_2d_condition.py,sha256=v3WS9EGKROE_QClXrxC7rmRko1BspAvAbeIfh83LK88,15832
38
- optimum/rbln/diffusers/pipelines/__init__.py,sha256=Ft1i48HP3wVi5t7PpIPNhL-bcxpLfwyZ5kuaTECAx1A,3392
38
+ optimum/rbln/diffusers/pipelines/__init__.py,sha256=r8mu21102cKXdkG1II9tpfpUS6wuyren2oK9y_MptZY,3703
39
+ optimum/rbln/diffusers/pipelines/auto_pipeline.py,sha256=oGZWXfj82w695D2NiYUitgoWiwP2Z4PlgA3q6hoOKww,9502
39
40
  optimum/rbln/diffusers/pipelines/controlnet/__init__.py,sha256=n1Ef22TSeax-kENi_d8K6wGGHSNEo9QkUeygELHgcao,983
40
41
  optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py,sha256=3S9dogIHW8Bqg5kIlCudhCQG-4g3FcdOPEWhBOf7CJA,4059
41
42
  optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py,sha256=G96bh4D9Cu-w4F9gZBQF6wNzhJQv9kvI34ZFsuEDjSw,35714
@@ -195,7 +196,7 @@ optimum/rbln/transformers/models/siglip/modeling_siglip.py,sha256=1TyRaxmhp6mg6U
195
196
  optimum/rbln/transformers/models/t5/__init__.py,sha256=R1Q8Z1vaIdx4rDjeCmm_ZMSgewWaqaI0l93AHwewtew,818
196
197
  optimum/rbln/transformers/models/t5/configuration_t5.py,sha256=nqDbibqykeeWn1TlKk6LmCn-DawTVudMMuBn2c2jds8,1362
197
198
  optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=pdAWBLVknTzbma0Ij-VQ2Qve-frPjxL-AwMyU-zouPY,5123
198
- optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=X_9X4QRhkiiMrwFHv3mzER3yGmF9oQ2U-HdH6jbwVmw,9824
199
+ optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=DlJNrGk35NTBhcp76PEhiyfs5yuUoDWKvMhfe4_puIE,10171
199
200
  optimum/rbln/transformers/models/time_series_transformer/__init__.py,sha256=xJaFWQawlwtv4H5tVFcY1pxLYzjHtMAlLq6nXysdkN8,1243
200
201
  optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py,sha256=MO-T4pcsea4EOmYeeg0tosUH6w76azqIPyV8Em8CMqw,1621
201
202
  optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py,sha256=8orxM-LbShCt2jC8Uyx43cSxWN1CGxamS58pKPjvzxs,17167
@@ -215,7 +216,7 @@ optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=O3o2KzJ8Li3QhB7G
215
216
  optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py,sha256=wHRpGTXL9khYqSkKL1IgA7__6_lt9QpOz9tHumjK7fo,1260
216
217
  optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=EZd3flRUEE38DYtdqEnG70LV7fHhkamRZV51xrVyjYI,1093
217
218
  optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
218
- optimum/rbln/transformers/utils/rbln_quantization.py,sha256=PQY46_Yq_ic6n8F_RsZSumdFNd_NGKHfVNHNxDbVia0,17578
219
+ optimum/rbln/transformers/utils/rbln_quantization.py,sha256=ARngdvRmeVoOphUU3Md9kT6zS5HDrYdEFYljJwaAaio,21020
219
220
  optimum/rbln/utils/__init__.py,sha256=ieDBT2VFTt2E0M4v_POLBpuGW9LxSydpb_DuPd6PQqc,712
220
221
  optimum/rbln/utils/decorator_utils.py,sha256=xu-TrsNi33SRC2a7DBsyoo6-pEQxWKZPZSmM9QlDe2Y,3745
221
222
  optimum/rbln/utils/depreacate_utils.py,sha256=uKxl3ENUCNaZXPnaDQvNxrH8hUIWdBWfZH6BM7ZV__4,385
@@ -226,7 +227,7 @@ optimum/rbln/utils/model_utils.py,sha256=4k5879Kh75m3x_vS4-qOGfqsOiAvc2kdNFFfvsF
226
227
  optimum/rbln/utils/runtime_utils.py,sha256=R6uXDbeJP03-FWdd4vthNe2D4aCra5n12E3WB1ifiGM,7933
227
228
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
228
229
  optimum/rbln/utils/submodule.py,sha256=w5mgPgncI740gVKMu3S-69DGNdUSI0bTZxegQGcZ98Y,5011
229
- optimum_rbln-0.8.3a0.dist-info/METADATA,sha256=pv8AVPfkvMkms_pTvelG637GLOE0DdTIsCfJSLMSjfQ,5299
230
- optimum_rbln-0.8.3a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
231
- optimum_rbln-0.8.3a0.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
232
- optimum_rbln-0.8.3a0.dist-info/RECORD,,
230
+ optimum_rbln-0.8.3a2.dist-info/METADATA,sha256=KAOx0J5beZebrxsAf9AsklRO43eTWaw222WX1iInnpk,5299
231
+ optimum_rbln-0.8.3a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
232
+ optimum_rbln-0.8.3a2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
233
+ optimum_rbln-0.8.3a2.dist-info/RECORD,,