optimum-rbln 0.8.2rc0 → 0.8.3a0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (90)
  1. optimum/rbln/__init__.py +4 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/configuration_utils.py +4 -4
  4. optimum/rbln/diffusers/__init__.py +1 -0
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  22. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
  23. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  24. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  25. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  26. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  27. optimum/rbln/modeling.py +2 -2
  28. optimum/rbln/modeling_base.py +12 -4
  29. optimum/rbln/ops/attn.py +158 -0
  30. optimum/rbln/ops/flash_attn.py +166 -0
  31. optimum/rbln/transformers/__init__.py +2 -0
  32. optimum/rbln/transformers/configuration_generic.py +4 -4
  33. optimum/rbln/transformers/modeling_generic.py +1 -4
  34. optimum/rbln/transformers/modeling_outputs.py +37 -0
  35. optimum/rbln/transformers/models/__init__.py +6 -16
  36. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  37. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  38. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  39. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  43. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  44. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  45. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  46. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
  47. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +101 -91
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
  49. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +296 -986
  51. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  52. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
  53. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  54. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
  55. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +19 -250
  56. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
  57. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  58. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  59. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
  60. optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
  61. optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  64. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
  65. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
  66. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
  67. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
  68. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
  69. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
  70. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  71. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
  72. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
  73. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
  74. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
  75. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
  76. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
  77. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  78. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  79. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  80. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  81. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  82. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
  83. optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
  84. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  85. optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
  86. optimum/rbln/utils/runtime_utils.py +3 -3
  87. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
  88. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
  89. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
  90. {optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/siglip/__init__.py

@@ -12,9 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_siglip import (
-     RBLNSiglipVisionModelConfig,
- )
- from .modeling_siglip import (
-     RBLNSiglipVisionModel,
- )
+ from .configuration_siglip import RBLNSiglipVisionModelConfig
+ from .modeling_siglip import RBLNSiglipVisionModel
optimum/rbln/transformers/models/siglip/modeling_siglip.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

  import torch
  from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -126,7 +126,7 @@ class RBLNSiglipVisionModel(RBLNModel):
          output_attentions: bool = None,
          output_hidden_states: bool = None,
          interpolate_pos_encoding: bool = False,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ) -> Union[Tuple, BaseModelOutputWithPooling]:
          if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
              logger.warning(
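Note: the `**kwargs` annotation change above recurs across many files in this release (most of the +2/-2 entries in the file list). It is a typing correctness fix, not a behavior change; a minimal illustration, assuming only PEP 484 semantics:

from typing import Any, Dict

# Under PEP 484, the annotation on **kwargs describes each *value*; the
# Dict[str, ...] container type is implied. So:
def forward(self, **kwargs: Any): ...  # kwargs is Dict[str, Any]

# The old spelling wrongly declared every value to itself be a dict:
def forward_old(self, **kwargs: Dict[str, Any]): ...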
optimum/rbln/transformers/models/t5/configuration_t5.py

@@ -32,3 +32,5 @@ class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
      This configuration class stores the configuration parameters specific to
      RBLN-optimized T5 models for conditional text generation tasks.
      """
+
+     support_paged_attention = False
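Note: `support_paged_attention` is a class-level opt-out; T5 keeps the non-paged KV cache path while other seq2seq models (e.g. Whisper, below) gain paged-attention options in this release. A minimal check, assuming the config class is re-exported at the package root like the other RBLN configs:

from optimum.rbln import RBLNT5ForConditionalGenerationConfig

assert RBLNT5ForConditionalGenerationConfig.support_paged_attention is False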
optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py

@@ -1,4 +1,4 @@
- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig

@@ -17,7 +17,7 @@ class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
          enc_max_seq_len: Optional[int] = None,
          dec_max_seq_len: Optional[int] = None,
          num_parallel_samples: Optional[int] = None,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ):
          """
          Args:
optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py

@@ -23,24 +23,20 @@

  import inspect
  import logging
- from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     PretrainedConfig,
-     TimeSeriesTransformerForPrediction,
-     TimeSeriesTransformerModel,
- )
- from transformers.modeling_outputs import ModelOutput, SampleTSPredictionOutput, Seq2SeqTSModelOutput
+ from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
+ from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
  from transformers.modeling_utils import no_init_weights

  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.runtime_utils import RBLNPytorchRuntime
+ from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
  from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
  from .time_series_transformers_architecture import TimeSeriesTransformersWrapper

@@ -113,12 +109,6 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          )


- @dataclass
- class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
-     last_hidden_states: torch.FloatTensor = None
-     params: Tuple[torch.FloatTensor] = None
-
-
  class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
      """
      The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
optimum/rbln/transformers/models/whisper/configuration_whisper.py

@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict
+ from typing import Any

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
@@ -36,7 +36,9 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
          use_attention_mask: bool = None,
          enc_max_seq_len: int = None,
          dec_max_seq_len: int = None,
-         **kwargs: Dict[str, Any],
+         kvcache_num_blocks: int = None,
+         kvcache_block_size: int = None,
+         **kwargs: Any,
      ):
          """
          Args:
@@ -45,6 +47,10 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
              use_attention_mask (bool, optional): Whether to use attention masks during inference. This is automatically
              enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
              dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
+             kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
+                 PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+             kvcache_block_size (int, optional): Sets the size (in number of tokens) of each block
+                 in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
              **kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
@@ -62,3 +68,5 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):

          self.use_attention_mask = use_attention_mask
          self.use_attention_mask = self.use_attention_mask or False
+         self.kvcache_num_blocks = kvcache_num_blocks
+         self.kvcache_block_size = kvcache_block_size
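Note: with the defaults (`kvcache_num_blocks = batch_size`, `kvcache_block_size = dec_max_seq_len`), each sequence gets exactly one cache block, i.e. a contiguous per-sequence KV cache; `_update_paged_attention_config` below enforces exactly that until flash attention is supported. A hedged usage sketch, assuming the usual `from_pretrained(..., export=True, rbln_config=...)` entry point of optimum-rbln:

from optimum.rbln import RBLNWhisperForConditionalGeneration

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
    export=True,
    rbln_config={
        "batch_size": 1,
        # These equal the derived defaults; any other values currently raise
        # NotImplementedError (see _update_paged_attention_config below).
        "kvcache_num_blocks": 1,    # must equal batch_size
        "kvcache_block_size": 448,  # must equal dec_max_seq_len (Whisper's max_length)
    },
)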
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -46,7 +46,7 @@ if TYPE_CHECKING:
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]

-     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+     def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
          output = super().forward(*args, **kwargs)
          return BaseModelOutput(last_hidden_state=output)

@@ -253,6 +253,23 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):

          return {"encoder": compiled_encoder, "decoder": compiled_decoder}

+     @classmethod
+     def _update_paged_attention_config(
+         cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
+     ):
+         rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+         rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+         if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+             raise NotImplementedError(
+                 f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+             )
+
+         if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+             raise NotImplementedError(
+                 f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+             )
+
      @classmethod
      def _update_rbln_config(
          cls,
@@ -270,6 +287,8 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
          if rbln_config.dec_max_seq_len is None:
              rbln_config.dec_max_seq_len = model_config.max_length

+         cls._update_paged_attention_config(model_config, rbln_config)
+
          enc_input_info = [
              ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
              ("block_tables", [1], "int16"),
optimum/rbln/transformers/models/xlm_roberta/__init__.py

@@ -12,14 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_xlm_roberta import (
-     RBLNXLMRobertaForSequenceClassificationConfig,
-     RBLNXLMRobertaModelConfig,
- )
- from .modeling_xlm_roberta import (
-     RBLNXLMRobertaForSequenceClassification,
-     RBLNXLMRobertaModel,
- )
+ from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
+ from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel


  __all__ = [
optimum/rbln/transformers/utils/rbln_quantization.py

@@ -13,10 +13,12 @@
  # limitations under the License.

  import glob
+ import json
  import os
  from typing import Any, Dict, Optional, Union

  import torch
+ from huggingface_hub import hf_hub_download, list_repo_files
  from safetensors.torch import load_file
  from torch.nn import Linear, Parameter
  from torch.nn import functional as F
@@ -30,21 +32,24 @@ logger = get_logger()

  class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
      SUPPORTED_FORMATS = ["rbln"]
-     SUPPORTED_WEIGHTS = ["int4", "fp16"]
-     SUPPORTED_ACTIVATIONS = ["fp16"]
-
-     # The RBLN_QUANT_BITS environment variable defines the precision of each layer during the graph compilation process.
-     # It specifies the quantization bit depth. For instance, setting RBLN_QUANT_BITS=4 will apply 4-bit precision for quantization.
+     SUPPORTED_WEIGHTS = ["int4", "fp8", "fp16"]
+     SUPPORTED_ACTIVATIONS = ["fp8", "fp16"]
+     SUPPORTED_KVCACHES = ["fp8", "fp16"]
      RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"

      def __init__(
          self,
          format: Optional[str] = None,
-         precision: Optional[str] = None,
          weights: Optional[str] = None,
          activations: Optional[str] = None,
+         kv_caches: Optional[str] = None,
+         *,
+         precision: Optional[str] = None,
      ):
-         self.format = format
+         self.format = format or "rbln"
+         if self.format not in self.SUPPORTED_FORMATS:
+             raise ValueError(f"Invalid format: {self.format}, supported formats are: {self.SUPPORTED_FORMATS}")
+
          if precision is not None:
              logger.warning("The `precision` argument is deprecated. Use `weights` and `activations` instead.")
              if any(precision_arg is not None for precision_arg in (weights, activations)):
@@ -58,6 +63,8 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):

          self.weights = weights or "fp16"
          self.activations = activations or "fp16"
+         self.kv_caches = kv_caches or "fp16"
+
          self._validate()

      def _validate(self):
@@ -69,27 +76,49 @@ class RBLNQuantizationConfig(RBLNSerializableConfigProtocol):
              raise ValueError(
                  f"Invalid activations: {self.activations}, supported activations are: {self.SUPPORTED_ACTIVATIONS}"
              )
+         if self.kv_caches not in self.SUPPORTED_KVCACHES:
+             raise ValueError(
+                 f"Invalid kv_caches: {self.kv_caches}, supported kv_caches are: {self.SUPPORTED_KVCACHES}"
+             )
          if self.weights == "fp16" and self.activations == "fp16":
-             raise ValueError("weights and activations cannot be both fp16. It is meaningless.")
+             raise ValueError("weights and activations of QuantizationConfig cannot be both fp16. It is meaningless.")

      def _prepare_for_serialization(self) -> Dict[str, Any]:
          return {
              "format": self.format,
              "weights": self.weights,
              "activations": self.activations,
+             "kv_caches": self.kv_caches,
          }

      def maybe_set_quantization_env(self):
-         quant_bits = None
          if self.weights == "int4":
-             quant_bits = "4"
-             os.environ[self.RBLN_QUANT_BITS_ENV] = quant_bits
+             os.environ[self.RBLN_QUANT_BITS_ENV] = "4"

      def maybe_reset_quantization_env(self):
          if self.RBLN_QUANT_BITS_ENV in os.environ:
              os.environ.pop(self.RBLN_QUANT_BITS_ENV)


+ class QuantizedLayerFactory:
+     def __init__(self, quantization_config: RBLNQuantizationConfig):
+         self.quantization_config = quantization_config
+
+     def create_linear(self, layer: Linear) -> Linear:
+         if self.quantization_config.weights == "int4":
+             return self.create_qlinear(layer)
+         elif self.quantization_config.weights == "fp8":
+             return self.create_fp8linear(layer)
+         else:
+             raise ValueError(f"Invalid quantization weights: {self.quantization_config.weights}")
+
+     def create_qlinear(self, layer: Linear) -> Linear:
+         return create_qlinear(layer, self.quantization_config)
+
+     def create_fp8linear(self, layer: Linear) -> Linear:
+         return create_fp8linear(layer, self.quantization_config)
+
+
  # Constants
  QUANTIZED_WEIGHTS = {
      "q_proj",
@@ -111,64 +140,60 @@ def prepare_model_for_quantization(
      cache_dir: Optional[str] = None,
      force_download: bool = False,
      local_files_only: bool = False,
+     rbln_quantization: Optional[RBLNQuantizationConfig] = None,
  ) -> torch.nn.Module:
      """
      Prepare the model for quantization by updating specified linear layers to quantized (qlinear) layers.
      """
-     update_layers_to_quantize(model)
-     load_weights(
-         model,
+
+     # 1. Load weight files and safetensors.index.json
+     safetensor_files, index_data = load_weight_files_and_index(
          model_id,
-         n_layer,
          use_auth_token=use_auth_token,
          revision=revision,
          cache_dir=cache_dir,
          force_download=force_download,
          local_files_only=local_files_only,
      )
-     return model

+     # 2. Determine format from safetensors.index.json
+     determined_format = determine_format_from_index(index_data)

-
- def update_layers_to_quantize(module: torch.nn.Module) -> None:
-     """
-     Updates specified linear layers to quantized (qlinear) layers in the given module.
-     """
-
-     logger.debug("Updating layers to be quantized")  # TODO(jongho): remove.
-     processed_layers = []
+     # 3. Update linear layers based on the determined format
+     update_layers_to_quantize(model, rbln_quantization)

-     for name, layer in module.named_modules():
-         if is_target_for_qlinear_replacement(name, layer):
-             parent_module, layer_name = get_parent_and_child(module, name)
-             setattr(parent_module, layer_name, create_qlinear(layer))
-             processed_layers.append(name)
+     # 4. Load weights into model parameters
+     load_weights_from_files(
+         model,
+         safetensor_files,
+         n_layer,
+         rbln_quantization=rbln_quantization,
+         determined_format=determined_format,
+     )

-     if processed_layers:
-         logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+     return model


- def load_weights(
-     model,
-     model_id,
-     n_layer=None,
-     use_auth_token=None,
-     revision=None,
-     cache_dir=None,
-     force_download=False,
-     local_files_only=False,
- ):
+ def load_weight_files_and_index(
+     model_id: str,
+     use_auth_token: Optional[Union[bool, str]] = None,
+     revision: Optional[str] = None,
+     cache_dir: Optional[str] = None,
+     force_download: bool = False,
+     local_files_only: bool = False,
+ ) -> tuple[list[str], Optional[Dict]]:
      """
      Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
      """
-
-     model_params = dict(model.named_parameters(recurse=True))
-     model_buffers = dict(model.named_buffers(recurse=True))
+     index_data = None

      if os.path.isdir(model_id):
          safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+         index_path = os.path.join(model_id, "model.safetensors.index.json")
+         if os.path.exists(index_path):
+             with open(index_path, "r") as f:
+                 index_data = json.load(f)
      else:
-         from huggingface_hub import hf_hub_download, list_repo_files
-
          try:
              # List all files in the repository
              repo_files = list_repo_files(model_id, revision=revision, token=use_auth_token)
@@ -188,6 +213,20 @@
                      local_files_only=local_files_only,
                  )
                  safetensor_files.append(downloaded_file)
+             elif file == "model.safetensors.index.json":
+                 # Download the index file
+                 index_file = hf_hub_download(
+                     repo_id=model_id,
+                     filename=file,
+                     revision=revision,
+                     token=use_auth_token,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     local_files_only=local_files_only,
+                 )
+
+                 with open(index_file, "r") as f:
+                     index_data = json.load(f)
      except Exception as e:
          logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
          raise e
@@ -195,12 +234,85 @@
      if not safetensor_files:
          raise FileNotFoundError(f"No safetensors files found for model_id: {model_id}")

+     return safetensor_files, index_data
+
+
+ def determine_format_from_index(index_data: Optional[Dict]) -> str:
+     """
+     Determine the quantization format from safetensors.index.json data.
+
+     Args:
+         index_data: The loaded safetensors.index.json content
+
+     Returns:
+         str: The determined format string
+     """
+     if index_data is None:
+         raise ValueError("safetensors.index.json not found")
+     if "weight_map" not in index_data:
+         raise ValueError("weight_map not found in safetensors.index.json")
+
+     if any("self_attn.k_proj.k_scale" in key for key in index_data["weight_map"]):
+         return "tensorrt"
+     elif any("self_attn.kv_scale" in key for key in index_data["weight_map"]):
+         return "quark"
+     elif any("weight_scale" in key or "input_scale" in key for key in index_data["weight_map"]):
+         return "default"
+     else:
+         raise ValueError("Unknown quantization format of the index data of weight map.")
+
+
+ def update_layers_to_quantize(
+     module: torch.nn.Module,
+     rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+ ) -> None:
+     """
+     Updates specified linear layers to quantized (qlinear) layers in the given module.
+     """
+
+     processed_layers = []
+     quantized_layer_factory = QuantizedLayerFactory(rbln_quantization)
+
+     for name, layer in module.named_modules():
+         if is_target_for_qlinear_replacement(name, layer):
+             parent_module, layer_name = get_parent_and_child(module, name)
+             setattr(parent_module, layer_name, quantized_layer_factory.create_linear(layer))
+             processed_layers.append(name)
+
+     if processed_layers:
+         logger.debug(f"Updated the following linear layers to quantized layers:\n {{{', '.join(processed_layers)}}}")
+
+
+ def load_weights_from_files(
+     model: torch.nn.Module,
+     safetensor_files: list[str],
+     n_layer: Optional[int] = None,
+     rbln_quantization: Optional[RBLNQuantizationConfig] = None,
+     determined_format: Optional[str] = None,
+ ):
+     """
+     Load safetensor file data directly into the model from provided safetensor files,
+     filtering by layer if n_layer is provided.
+     """
+
+     model_params = dict(model.named_parameters(recurse=True))
+     model_buffers = dict(model.named_buffers(recurse=True))
+
      target_layers = list(range(n_layer)) if n_layer is not None else None

      unloaded_keys = []
+     loaded_input_scale = False
+     loaded_kv_scale = False
+     loaded_weight_scale = False
+
      for safetensor_file in safetensor_files:
          file_data = load_file(safetensor_file)
+
          for key, value in file_data.items():
+             loaded_input_scale = loaded_input_scale or "input_scale" in key
+             loaded_weight_scale = loaded_weight_scale or "weight_scale" in key
+             loaded_kv_scale = loaded_kv_scale or any(scale in key for scale in ["kv_scale", "k_scale", "v_scale"])
+
              if target_layers is not None:
                  parts = key.split(".")
211
323
  model_params[key].data.copy_(value)
212
324
  elif key in model_buffers:
213
325
  model_buffers[key].data.copy_(value)
326
+ elif "kv_scale" in key and determined_format == "quark":
327
+ if rbln_quantization.kv_caches == "fp8":
328
+ model_params[key.replace("kv_scale", "k_proj.k_scale")].data.copy_(value)
329
+ model_params[key.replace("kv_scale", "v_proj.v_scale")].data.copy_(value)
330
+ else:
331
+ unloaded_keys.append(key)
214
332
  else:
215
333
  unloaded_keys.append(key)
216
334
 
217
335
  if len(unloaded_keys) > 0:
218
336
  logger.warning(f"There are unexpected parameters/buffers on the checkpoint: {unloaded_keys}")
219
337
 
338
+ if not loaded_input_scale and rbln_quantization.activations == "fp8":
339
+ raise ValueError(
340
+ "No input_scale found in the checkpoint. Did you use the correct quantization config? "
341
+ "If you are using fp8 quantization, you need to use the correct quantization config."
342
+ )
343
+ if not loaded_weight_scale and rbln_quantization.weights == "fp8":
344
+ raise ValueError(
345
+ "No weight_scale found in the checkpoint. Did you use the correct quantization config? "
346
+ "If you are using fp8 quantization, you need to use the correct quantization config."
347
+ )
348
+ if not loaded_kv_scale and rbln_quantization.kv_caches == "fp8":
349
+ raise ValueError(
350
+ "No kv_scale found in the checkpoint. Did you use the correct quantization config? "
351
+ "If you are using fp8 quantization, you need to use the correct quantization config."
352
+ )
353
+ if loaded_kv_scale and rbln_quantization.kv_caches != "fp8":
354
+ logger.warning(
355
+ "kv_scale found in the checkpoint, but kv_caches of quantization config is not fp8. Ignoring kv_scale."
356
+ )
357
+
220
358
 
221
359
  def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
222
360
  """
@@ -225,6 +363,10 @@ def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
      return layer_name.split(".")[-1] in QUANTIZED_WEIGHTS and isinstance(layer, torch.nn.Linear)


+ def is_target_for_adding_kv_scales(layer_name: str) -> bool:
+     return layer_name.split(".")[-1] in ["self_attn"]
+
+
  def get_parent_and_child(module: torch.nn.Module, full_name: str) -> tuple:
      """
      Splits the full layer name to retrieve the parent module and the child layer.
@@ -243,7 +385,7 @@ def access_attribute(obj: Any, attributes: list[str]) -> Any:
      return obj


- def create_qlinear(layer: Linear) -> Linear:
+ def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
      """
      Converts a standard linear layer to a quantized linear (qlinear) layer with a custom forward pass.
      """
@@ -262,3 +404,64 @@ def create_qlinear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
      layer.forward = lambda inputs: qlinear_forward(layer, inputs)

      return layer
+
+
+ def create_fp8linear(layer: Linear, rbln_quantization: RBLNQuantizationConfig) -> Linear:
+     """
+     Converts a standard linear layer to a fp8 linear layer with a custom forward pass.
+     """
+
+     def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+         finfo = torch.finfo(torch.float8_e4m3fn)
+         qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+         return qweight
+
+     def fp8_gemm(A: torch.Tensor, A_scale, B: torch.Tensor, B_scale, bias, out_dtype: torch.dtype):
+         A = A.type(out_dtype)
+         B = B.type(out_dtype)
+
+         if A_scale is not None:
+             A *= A_scale
+         if B_scale is not None:
+             B *= B_scale.to(out_dtype)
+
+         output = torch.nn.functional.linear(A, B, bias=bias)
+         return output
+
+     def fp8linear_forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.input_scale:
+             input = static_per_tensor_quantize(x, self.input_scale)
+         else:
+             input = x
+
+         if self.weight_scale:
+             # broadcast weight_scale to vector
+             weight_scale = self.weight_scale.broadcast_to(self.weight.shape[-1:])
+         else:
+             weight_scale = None
+         output = fp8_gemm(
+             A=input,
+             A_scale=self.input_scale,
+             B=self.weight,
+             B_scale=weight_scale,
+             bias=self.bias,
+             out_dtype=x.dtype,
+         )
+
+         return output
+
+     layer.weight = Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)
+     layer.weight_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+     if rbln_quantization.activations == "fp8":
+         layer.input_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+     else:
+         layer.input_scale = None
+
+     if rbln_quantization.kv_caches == "fp8":
+         layer.k_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+         layer.v_scale = Parameter(torch.tensor(1, dtype=torch.float32), requires_grad=False)
+
+     layer.forward = lambda inputs: fp8linear_forward(layer, inputs)
+
+     return layer
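Note: a minimal sketch of the conversion in isolation, assuming `create_fp8linear` and `RBLNQuantizationConfig` are imported from the module above and a PyTorch build with `torch.float8_e4m3fn`; the patched forward emulates the fp8 GEMM by dequantizing back to the activation dtype:

import torch
from torch.nn import Linear

layer = Linear(16, 16, bias=False)
cfg = RBLNQuantizationConfig(weights="fp8", activations="fp8", kv_caches="fp8")
fp8_layer = create_fp8linear(layer, cfg)

assert fp8_layer.weight.dtype == torch.float8_e4m3fn
y = fp8_layer(torch.randn(1, 16))  # dispatches to fp8linear_forward via the patched forward
assert y.dtype == torch.float32    # output is restored to the activation dtype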
optimum/rbln/utils/runtime_utils.py

@@ -14,7 +14,7 @@

  import re
  import threading
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, List, Optional, Union

  import rebel
  import torch
@@ -94,7 +94,7 @@ class RBLNPytorchRuntime:
      def __call__(self, *args: Any, **kwds: Any) -> Any:
          return self.forward(*args, **kwds)

-     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+     def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
          # filtering useless args or kwarg such as None.
          args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
          kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor) or kwarg[0] == "out", kwargs.items()))
@@ -142,7 +142,7 @@ class UnavailableRuntime:
          """Returns an iterator with self as the only item."""
          return iter([self])

-     def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
+     def forward(self, *args: List["torch.Tensor"], **kwargs: "torch.Tensor"):
          """Raises a detailed RuntimeError explaining why inference cannot be performed."""
          raise RuntimeError(
              "Cannot perform inference: RBLN runtime is not available.\n\n"
{optimum_rbln-0.8.2rc0.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.8.2rc0
+ Version: 0.8.3a0
  Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai