optimum-rbln 0.8.3rc0 (py3-none-any wheel) → 0.8.4a0 (py3-none-any wheel)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This release has been flagged as potentially problematic.
This version of optimum-rbln might be problematic; see the package's registry page for more details.
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +15 -0
- optimum/rbln/modeling.py +1 -4
- optimum/rbln/modeling_base.py +20 -6
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +6 -3
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +7 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +12 -31
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +2 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +47 -31
- {optimum_rbln-0.8.3rc0.dist-info → optimum_rbln-0.8.4a0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.3rc0.dist-info → optimum_rbln-0.8.4a0.dist-info}/RECORD +14 -14
- {optimum_rbln-0.8.3rc0.dist-info → optimum_rbln-0.8.4a0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.3rc0.dist-info → optimum_rbln-0.8.4a0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.8.3rc0'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 8, 3, 'rc0')
|
|
31
|
+
__version__ = version = '0.8.4a0'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 8, 4, 'a0')
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -476,6 +476,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
476
476
|
non_save_attributes = [
|
|
477
477
|
"_frozen",
|
|
478
478
|
"_runtime_options",
|
|
479
|
+
"torch_dtype",
|
|
479
480
|
"npu",
|
|
480
481
|
"tensor_parallel_size",
|
|
481
482
|
"create_runtimes",
|
|
@@ -566,6 +567,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
566
567
|
tensor_parallel_size: Optional[int] = None,
|
|
567
568
|
timeout: Optional[int] = None,
|
|
568
569
|
optimum_rbln_version: Optional[str] = None,
|
|
570
|
+
_torch_dtype: Optional[str] = None,
|
|
569
571
|
_compile_cfgs: List[RBLNCompileConfig] = [],
|
|
570
572
|
**kwargs: Any,
|
|
571
573
|
):
|
|
@@ -583,6 +585,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
583
585
|
tensor_parallel_size (Optional[int]): Size for tensor parallelism to distribute the model across devices.
|
|
584
586
|
timeout (Optional[int]): The timeout for the runtime in seconds. If it isn't provided, it will be set to 60 by default.
|
|
585
587
|
optimum_rbln_version (Optional[str]): The optimum-rbln version used for this configuration.
|
|
588
|
+
_torch_dtype (Optional[str]): The data type to use for the model.
|
|
586
589
|
_compile_cfgs (List[RBLNCompileConfig]): List of compilation configurations for the model.
|
|
587
590
|
**kwargs: Additional keyword arguments.
|
|
588
591
|
|
|
@@ -610,6 +613,7 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
610
613
|
self.npu = npu
|
|
611
614
|
self.tensor_parallel_size = tensor_parallel_size
|
|
612
615
|
|
|
616
|
+
self._torch_dtype = _torch_dtype or "float32"
|
|
613
617
|
self.optimum_rbln_version = optimum_rbln_version
|
|
614
618
|
if self.optimum_rbln_version is None:
|
|
615
619
|
self.optimum_rbln_version = __version__
|
|
@@ -639,6 +643,17 @@ class RBLNModelConfig(RBLNSerializableConfigProtocol):
|
|
|
639
643
|
|
|
640
644
|
raise ValueError(f"Unexpected arguments: {kwargs.keys()}")
|
|
641
645
|
|
|
646
|
+
@property
|
|
647
|
+
def torch_dtype(self):
|
|
648
|
+
return getattr(torch, self._torch_dtype)
|
|
649
|
+
|
|
650
|
+
@torch_dtype.setter
|
|
651
|
+
def torch_dtype(self, torch_dtype: Union[str, torch.dtype]):
|
|
652
|
+
if isinstance(torch_dtype, torch.dtype):
|
|
653
|
+
torch_dtype = RBLNCompileConfig.normalize_dtype(torch_dtype)
|
|
654
|
+
|
|
655
|
+
self._torch_dtype = torch_dtype
|
|
656
|
+
|
|
642
657
|
@property
|
|
643
658
|
def rbln_model_cls_name(self) -> str:
|
|
644
659
|
return self.__class__.__name__[:-6]
|
optimum/rbln/modeling.py
CHANGED
|
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, get_args, ge
|
|
|
19
19
|
import rebel
|
|
20
20
|
import torch
|
|
21
21
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
|
22
|
-
from transformers import
|
|
22
|
+
from transformers import PretrainedConfig
|
|
23
23
|
from transformers.modeling_outputs import BaseModelOutput
|
|
24
24
|
|
|
25
25
|
from .configuration_utils import DEFAULT_COMPILED_MODEL_NAME, RBLNModelConfig
|
|
@@ -119,9 +119,6 @@ class RBLNModel(RBLNBaseModel):
|
|
|
119
119
|
# Save configs
|
|
120
120
|
if config is None:
|
|
121
121
|
config = model.config
|
|
122
|
-
# remote_config
|
|
123
|
-
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
|
|
124
|
-
config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)
|
|
125
122
|
|
|
126
123
|
if hasattr(model, "can_generate") and model.can_generate():
|
|
127
124
|
import json
|
optimum/rbln/modeling_base.py
CHANGED
|
@@ -34,7 +34,7 @@ from .utils.submodule import SubModulesMixin
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
if TYPE_CHECKING:
|
|
37
|
-
from transformers import PreTrainedModel
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
38
38
|
|
|
39
39
|
logger = get_logger(__name__)
|
|
40
40
|
|
|
@@ -53,6 +53,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
53
53
|
config_class = AutoConfig
|
|
54
54
|
config_name = "config.json"
|
|
55
55
|
hf_library_name = "transformers"
|
|
56
|
+
_supports_non_fp32 = False
|
|
56
57
|
|
|
57
58
|
def __init__(
|
|
58
59
|
self,
|
|
@@ -91,7 +92,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
91
92
|
|
|
92
93
|
self.device = torch.device("cpu")
|
|
93
94
|
self.training = False
|
|
94
|
-
self.dtype =
|
|
95
|
+
self.dtype = rbln_config.torch_dtype
|
|
95
96
|
|
|
96
97
|
# FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
|
|
97
98
|
# This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
|
|
@@ -400,8 +401,21 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
400
401
|
return compiled_model
|
|
401
402
|
|
|
402
403
|
@classmethod
|
|
403
|
-
def update_rbln_config(
|
|
404
|
-
|
|
404
|
+
def update_rbln_config(
|
|
405
|
+
cls,
|
|
406
|
+
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
|
407
|
+
model: "PreTrainedModel",
|
|
408
|
+
model_config: "PretrainedConfig",
|
|
409
|
+
rbln_config: RBLNModelConfig,
|
|
410
|
+
) -> RBLNModelConfig:
|
|
411
|
+
rbln_config.torch_dtype = model.dtype
|
|
412
|
+
if not cls._supports_non_fp32 and rbln_config.torch_dtype != torch.float32:
|
|
413
|
+
raise NotImplementedError(
|
|
414
|
+
f"Currently, {cls.__name__} does not support non-fp32 dtype. Please use float32 dtype."
|
|
415
|
+
)
|
|
416
|
+
rbln_config = cls._update_rbln_config(
|
|
417
|
+
preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
|
|
418
|
+
)
|
|
405
419
|
rbln_config.freeze()
|
|
406
420
|
if rbln_config.rbln_model_cls_name != cls.__name__:
|
|
407
421
|
raise NameError(
|
|
@@ -444,12 +458,12 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
|
|
|
444
458
|
|
|
445
459
|
# This method mimics the interface of torch.nn.Module.parameters()
|
|
446
460
|
# specifically for code that uses `next(model.parameters())` to infer
|
|
447
|
-
# the device or dtype. It yields a single dummy tensor on CPU with
|
|
461
|
+
# the device or dtype. It yields a single dummy tensor on CPU with model dtype.
|
|
448
462
|
|
|
449
463
|
# Warning:
|
|
450
464
|
# This does NOT yield the actual model parameters used by the RBLN runtime.
|
|
451
465
|
# Code relying on iterating through all model parameters will not work as expected.
|
|
452
|
-
yield torch.tensor([1.0], dtype=
|
|
466
|
+
yield torch.tensor([1.0], dtype=self.dtype, device=torch.device("cpu"))
|
|
453
467
|
|
|
454
468
|
def __call__(self, *args, **kwargs):
|
|
455
469
|
return self.forward(*args, **kwargs)
|
|
@@ -1066,7 +1066,7 @@ class RotaryEmbedding(nn.Module):
|
|
|
1066
1066
|
rope_type = "default"
|
|
1067
1067
|
|
|
1068
1068
|
inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
|
|
1069
|
-
cache_position = torch.arange(0, max_seq_len_cached
|
|
1069
|
+
cache_position = torch.arange(0, max_seq_len_cached)
|
|
1070
1070
|
cache_position_expanded = cache_position[:, None]
|
|
1071
1071
|
|
|
1072
1072
|
if rope_type == "dynamic":
|
|
@@ -1085,8 +1085,8 @@ class RotaryEmbedding(nn.Module):
|
|
|
1085
1085
|
|
|
1086
1086
|
def forward(self, x, seq_len):
|
|
1087
1087
|
return (
|
|
1088
|
-
self._cos_cached[:seq_len].to(dtype=
|
|
1089
|
-
self._sin_cached[:seq_len].to(dtype=
|
|
1088
|
+
self._cos_cached[:seq_len].to(dtype=torch.float32),
|
|
1089
|
+
self._sin_cached[:seq_len].to(dtype=torch.float32),
|
|
1090
1090
|
)
|
|
1091
1091
|
|
|
1092
1092
|
|
|
@@ -1116,8 +1116,11 @@ def rotate_half(x):
|
|
|
1116
1116
|
|
|
1117
1117
|
def apply_rotary_pos_emb(q, k, cos, sin):
|
|
1118
1118
|
"""Applies Rotary Position Embedding to the query and key tensors."""
|
|
1119
|
+
dtype = q.dtype
|
|
1119
1120
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
1120
1121
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
1122
|
+
q_embed = q_embed.to(dtype)
|
|
1123
|
+
k_embed = k_embed.to(dtype)
|
|
1121
1124
|
return q_embed, k_embed
|
|
1122
1125
|
|
|
1123
1126
|
|
|
@@ -317,7 +317,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
317
317
|
|
|
318
318
|
# Initialize attention mask for chunked processing
|
|
319
319
|
chunked_attention_mask = (
|
|
320
|
-
torch.zeros(
|
|
320
|
+
torch.zeros(
|
|
321
|
+
1,
|
|
322
|
+
1,
|
|
323
|
+
self.rbln_config.prefill_chunk_size,
|
|
324
|
+
self.rbln_config.max_seq_len,
|
|
325
|
+
dtype=self.rbln_config.torch_dtype,
|
|
326
|
+
)
|
|
321
327
|
if self.rbln_config.use_attention_mask
|
|
322
328
|
else None
|
|
323
329
|
)
|
|
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
|
|
|
19
19
|
import rebel
|
|
20
20
|
import torch
|
|
21
21
|
from rebel.compile_context import CompileContext
|
|
22
|
-
from transformers import
|
|
22
|
+
from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
|
23
23
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
24
24
|
from transformers.modeling_utils import no_init_weights
|
|
25
25
|
|
|
@@ -33,7 +33,7 @@ from ...modeling_attention_utils import (
|
|
|
33
33
|
validate_sliding_window,
|
|
34
34
|
)
|
|
35
35
|
from ...modeling_outputs import RBLNDecoderOnlyOutput
|
|
36
|
-
from ...utils.rbln_quantization import
|
|
36
|
+
from ...utils.rbln_quantization import get_quantized_model
|
|
37
37
|
from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
38
38
|
from .decoderonly_architecture import DecoderOnlyWrapper
|
|
39
39
|
from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
|
|
@@ -72,6 +72,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
|
72
72
|
auto_model_class = AutoModel
|
|
73
73
|
_decoder_wrapper_cls = DecoderOnlyWrapper
|
|
74
74
|
_use_rotary_emb = True
|
|
75
|
+
_supports_non_fp32 = True
|
|
75
76
|
|
|
76
77
|
def __post_init__(self, **kwargs):
|
|
77
78
|
if self.rbln_config.use_inputs_embeds:
|
|
@@ -86,10 +87,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
|
86
87
|
def setup_runtime(self):
|
|
87
88
|
# Initialize resources to be used across Runtime instances (prefill and decode phases)
|
|
88
89
|
page_table_manager = RBLNPageTableManager(self.rbln_config)
|
|
89
|
-
dec_attn_mask = torch.zeros(
|
|
90
|
-
|
|
91
|
-
)
|
|
92
|
-
out_buffers = [torch.empty(self.prefill_output_size, dtype=torch.float32, device="cpu")]
|
|
90
|
+
dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
|
|
91
|
+
out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
|
|
93
92
|
|
|
94
93
|
common_kwargs = {
|
|
95
94
|
"main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
|
|
@@ -143,35 +142,17 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
|
143
142
|
):
|
|
144
143
|
kwargs = cls.update_kwargs(kwargs)
|
|
145
144
|
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
model_id,
|
|
149
|
-
use_auth_token=use_auth_token,
|
|
150
|
-
revision=revision,
|
|
151
|
-
force_download=force_download,
|
|
152
|
-
cache_dir=cache_dir,
|
|
153
|
-
trust_remote_code=trust_remote_code,
|
|
154
|
-
**kwargs,
|
|
155
|
-
)
|
|
156
|
-
if config.torch_dtype == torch.bfloat16:
|
|
157
|
-
# FIXME: bfloat16 is not supported by rebel-compiler
|
|
158
|
-
config.torch_dtype = torch.float32
|
|
159
|
-
|
|
160
|
-
with no_init_weights():
|
|
161
|
-
model = cls.auto_model_class.from_config(config)
|
|
162
|
-
|
|
163
|
-
model = prepare_model_for_quantization(
|
|
164
|
-
model,
|
|
145
|
+
return get_quantized_model(
|
|
146
|
+
cls.auto_model_class,
|
|
165
147
|
model_id,
|
|
166
|
-
kwargs.get("num_hidden_layers"),
|
|
167
148
|
use_auth_token=use_auth_token,
|
|
168
149
|
revision=revision,
|
|
169
150
|
cache_dir=cache_dir,
|
|
170
151
|
force_download=force_download,
|
|
171
152
|
local_files_only=local_files_only,
|
|
172
153
|
rbln_quantization=rbln_config.quantization,
|
|
154
|
+
**kwargs,
|
|
173
155
|
)
|
|
174
|
-
return model
|
|
175
156
|
|
|
176
157
|
def __getattr__(self, __name: str) -> Any:
|
|
177
158
|
# Special method to delegate attribute access to the original Huggingface LM class.
|
|
@@ -365,7 +346,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
|
365
346
|
|
|
366
347
|
input_info = []
|
|
367
348
|
if rbln_config.use_inputs_embeds:
|
|
368
|
-
input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size],
|
|
349
|
+
input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
|
|
369
350
|
else:
|
|
370
351
|
input_info.append(("input_ids", [batch_size, query_length], "int64"))
|
|
371
352
|
|
|
@@ -384,16 +365,16 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
|
384
365
|
|
|
385
366
|
if rbln_config.use_attention_mask:
|
|
386
367
|
if rbln_config.use_position_ids:
|
|
387
|
-
input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len],
|
|
368
|
+
input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
|
|
388
369
|
else:
|
|
389
370
|
input_info.append(
|
|
390
|
-
("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len],
|
|
371
|
+
("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
|
|
391
372
|
)
|
|
392
373
|
|
|
393
374
|
if rbln_config.use_position_ids:
|
|
394
375
|
input_info.append(("position_ids", [batch_size, query_length], "int32"))
|
|
395
376
|
|
|
396
|
-
kvcache_dtype =
|
|
377
|
+
kvcache_dtype = rbln_config.torch_dtype
|
|
397
378
|
if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
|
|
398
379
|
kvcache_dtype = "float8_e4m3fn"
|
|
399
380
|
|
|
@@ -345,6 +345,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
345
345
|
"""
|
|
346
346
|
|
|
347
347
|
_decoder_wrapper_cls = Gemma3ForCausalLMWrapper
|
|
348
|
+
_supports_non_fp32 = False
|
|
348
349
|
|
|
349
350
|
def setup_runtime(self):
|
|
350
351
|
# Initialize shared resources to be used across Runtime instances (prefill and decode phases)
|
|
@@ -14,18 +14,23 @@
|
|
|
14
14
|
|
|
15
15
|
import glob
|
|
16
16
|
import os
|
|
17
|
-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union
|
|
18
18
|
|
|
19
19
|
import torch
|
|
20
20
|
from huggingface_hub import hf_hub_download, list_repo_files
|
|
21
21
|
from safetensors.torch import load_file
|
|
22
22
|
from torch.nn import Linear, Parameter
|
|
23
23
|
from torch.nn import functional as F
|
|
24
|
+
from transformers import AutoConfig
|
|
25
|
+
from transformers.modeling_utils import get_state_dict_dtype, no_init_weights
|
|
24
26
|
|
|
25
27
|
from ...configuration_utils import RBLNSerializableConfigProtocol
|
|
26
28
|
from ...utils.logging import get_logger
|
|
27
29
|
|
|
28
30
|
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from transformers.models.auto.modeling_auto import _BaseAutoModelClass
|
|
33
|
+
|
|
29
34
|
logger = get_logger()
|
|
30
35
|
|
|
31
36
|
|
|
@@ -138,22 +143,31 @@ class QuantizedLayerFactory:
|
|
|
138
143
|
return create_fp8linear(layer, self.quantization_config)
|
|
139
144
|
|
|
140
145
|
|
|
141
|
-
def
|
|
142
|
-
|
|
146
|
+
def get_quantized_model(
|
|
147
|
+
hf_auto_model_class: Type["_BaseAutoModelClass"],
|
|
143
148
|
model_id: str,
|
|
144
|
-
n_layer: Optional[int] = None,
|
|
145
149
|
use_auth_token: Optional[Union[bool, str]] = None,
|
|
146
150
|
revision: Optional[str] = None,
|
|
147
151
|
cache_dir: Optional[str] = None,
|
|
148
152
|
force_download: bool = False,
|
|
149
153
|
local_files_only: bool = False,
|
|
150
154
|
rbln_quantization: Optional[RBLNQuantizationConfig] = None,
|
|
151
|
-
|
|
155
|
+
**kwargs,
|
|
156
|
+
):
|
|
152
157
|
"""
|
|
153
|
-
|
|
158
|
+
Get a quantized model from a model class and model id.
|
|
154
159
|
"""
|
|
160
|
+
# torch_dtype should not be passed to AutoConfig.from_pretrained
|
|
161
|
+
# since it doesn't support 'auto'
|
|
162
|
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
|
163
|
+
if torch_dtype is not None:
|
|
164
|
+
logger.warning(
|
|
165
|
+
"torch_dtype is not supported for quantized models. "
|
|
166
|
+
"It will be ignored and the dtype of the model will be determined by the weights."
|
|
167
|
+
)
|
|
168
|
+
torch_dtype = None
|
|
155
169
|
|
|
156
|
-
#
|
|
170
|
+
# get paths of safetensors files in the model repo
|
|
157
171
|
safetensor_files = load_weight_files(
|
|
158
172
|
model_id,
|
|
159
173
|
use_auth_token=use_auth_token,
|
|
@@ -163,17 +177,31 @@ def prepare_model_for_quantization(
|
|
|
163
177
|
local_files_only=local_files_only,
|
|
164
178
|
)
|
|
165
179
|
|
|
166
|
-
#
|
|
167
|
-
|
|
180
|
+
# load safetensors files into memory
|
|
181
|
+
safetensors = [load_file(safetensor_file) for safetensor_file in safetensor_files]
|
|
182
|
+
|
|
183
|
+
# get the dtype of the model from the first safetensor file
|
|
184
|
+
torch_dtype = get_state_dict_dtype(safetensors[0])
|
|
168
185
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
186
|
+
config = AutoConfig.from_pretrained(
|
|
187
|
+
model_id,
|
|
188
|
+
use_auth_token=use_auth_token,
|
|
189
|
+
revision=revision,
|
|
190
|
+
cache_dir=cache_dir,
|
|
191
|
+
force_download=force_download,
|
|
192
|
+
local_files_only=local_files_only,
|
|
193
|
+
**kwargs,
|
|
175
194
|
)
|
|
176
195
|
|
|
196
|
+
with no_init_weights():
|
|
197
|
+
model = hf_auto_model_class.from_config(config, torch_dtype=torch_dtype)
|
|
198
|
+
|
|
199
|
+
# Quantize the model
|
|
200
|
+
update_layers_to_quantize(model, rbln_quantization)
|
|
201
|
+
|
|
202
|
+
# Load weights into the model
|
|
203
|
+
load_weights_from_files(model, safetensors, rbln_quantization)
|
|
204
|
+
|
|
177
205
|
return model
|
|
178
206
|
|
|
179
207
|
|
|
@@ -372,32 +400,26 @@ def canonicalize_checkpoint_items(
|
|
|
372
400
|
|
|
373
401
|
def load_weights_from_files(
|
|
374
402
|
model: torch.nn.Module,
|
|
375
|
-
|
|
376
|
-
n_layer: Optional[int] = None,
|
|
403
|
+
safetensors: List[Dict[str, torch.Tensor]],
|
|
377
404
|
rbln_quantization: Optional[RBLNQuantizationConfig] = None,
|
|
378
405
|
):
|
|
379
406
|
"""
|
|
380
|
-
Load safetensor file data directly into the model from provided safetensor files
|
|
381
|
-
filtering by layer if n_layer is provided.
|
|
407
|
+
Load safetensor file data directly into the model from provided safetensor files.
|
|
382
408
|
"""
|
|
383
409
|
|
|
384
410
|
model_params = dict(model.named_parameters(recurse=True))
|
|
385
411
|
model_buffers = dict(model.named_buffers(recurse=True))
|
|
386
412
|
|
|
387
|
-
target_layers = list(range(n_layer)) if n_layer is not None else None
|
|
388
|
-
|
|
389
413
|
unloaded_keys = []
|
|
390
414
|
loaded_input_scale = False
|
|
391
415
|
loaded_kv_scale = False
|
|
392
416
|
loaded_weight_scale = False
|
|
393
417
|
|
|
394
|
-
for
|
|
395
|
-
file_data = load_file(safetensor_file)
|
|
396
|
-
|
|
418
|
+
for safetensor in safetensors:
|
|
397
419
|
# Normalize all (key, tensor) pairs to the internal schema
|
|
398
420
|
normalized_items = canonicalize_checkpoint_items(
|
|
399
421
|
model=model,
|
|
400
|
-
items=
|
|
422
|
+
items=safetensor.items(),
|
|
401
423
|
rbln_quantization=rbln_quantization,
|
|
402
424
|
)
|
|
403
425
|
|
|
@@ -410,12 +432,6 @@ def load_weights_from_files(
|
|
|
410
432
|
if key.endswith("k_scale") or key.endswith("v_scale"):
|
|
411
433
|
loaded_kv_scale = True
|
|
412
434
|
|
|
413
|
-
# Filter by layer index if requested
|
|
414
|
-
if target_layers is not None:
|
|
415
|
-
parts = key.split(".")
|
|
416
|
-
if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
|
|
417
|
-
continue
|
|
418
|
-
|
|
419
435
|
# Copy into parameters or buffers
|
|
420
436
|
if key in model_params:
|
|
421
437
|
# Ensure dtype compatibility
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: optimum-rbln
|
|
3
|
-
Version: 0.8.3rc0
|
|
3
|
+
Version: 0.8.4a0
|
|
4
4
|
Summary: Optimum RBLN is the interface between the HuggingFace Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
|
|
5
5
|
Project-URL: Homepage, https://rebellions.ai
|
|
6
6
|
Project-URL: Documentation, https://docs.rbln.ai
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
optimum/rbln/__init__.py,sha256=32ouGKDGus9k5_kD27CxP8jIQOw66zpDTfS0xs1XlfE,18298
|
|
2
|
-
optimum/rbln/__version__.py,sha256=
|
|
3
|
-
optimum/rbln/configuration_utils.py,sha256=
|
|
4
|
-
optimum/rbln/modeling.py,sha256=
|
|
5
|
-
optimum/rbln/modeling_base.py,sha256=
|
|
2
|
+
optimum/rbln/__version__.py,sha256=YNGYpHnDhFwKFL4ZTx3BIJGtmgon0Pv2G2E10GhWRaY,712
|
|
3
|
+
optimum/rbln/configuration_utils.py,sha256=KtbDM7HnFGiO0PsuvkrCE3R9NF6OJVmV_fyQcQNrmUk,34469
|
|
4
|
+
optimum/rbln/modeling.py,sha256=cAIPWEw5DGzUWeqjCbocRhU6OO3jyhVGW60AmBLh1Nw,14134
|
|
5
|
+
optimum/rbln/modeling_base.py,sha256=kQsBfUoDncNgR5P8_BvyzY6H_4YEXOBzN20lFmOZV_g,26190
|
|
6
6
|
optimum/rbln/diffusers/__init__.py,sha256=1tgU_xWA42BmInqu9bBz_5R_E9TGhhK3mI06YlaiTLg,7232
|
|
7
7
|
optimum/rbln/diffusers/modeling_diffusers.py,sha256=TAuMb7PSMjNwK7mh5ItE_CtAEgYeZKI27XkFFmxjHlQ,19902
|
|
8
8
|
optimum/rbln/diffusers/configurations/__init__.py,sha256=vMRnPY4s-Uju43xP038D2EA18X_mhy2YfsZVpSU-VoA,1322
|
|
@@ -105,10 +105,10 @@ optimum/rbln/transformers/models/colpali/configuration_colpali.py,sha256=eDWPVlo
|
|
|
105
105
|
optimum/rbln/transformers/models/colpali/modeling_colpali.py,sha256=v9rPLmNx-BQZhDFhKnr2kmARElTtKdFZCgFIU4m-HPw,15703
|
|
106
106
|
optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=w3VZOIBYaHXVdnuhK4y0zWAj0IAv7_5LGTJYaz9oYmI,1056
|
|
107
107
|
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py,sha256=H2i9Iefy-q5X-0BLWQ-CrxK8ZoT3p9t0lt_3r4TFSCY,15182
|
|
108
|
-
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=
|
|
109
|
-
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py,sha256=
|
|
108
|
+
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=L5LArhjN36fTdiwrUABgn3cnS7hh4SVCF4FMHBbiLZU,42760
|
|
109
|
+
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py,sha256=v3mfIlQImQkYYr-rPn7rQR3GYdVUhALRttEduLI7H9c,20012
|
|
110
110
|
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py,sha256=4D89IF0yQju_Dp_vLJN_dBkpe2U_LMWaUciYx57D-0M,3379
|
|
111
|
-
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=
|
|
111
|
+
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=dAHV9NgdpXHyTJGT0lieXOB3Pzi_NPlR4rqmRtmAWzM,32412
|
|
112
112
|
optimum/rbln/transformers/models/depth_anything/__init__.py,sha256=xvPSIriMJWyNeVYoVB1Z7YqB4kkHOIkaHq7loNps-dk,756
|
|
113
113
|
optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py,sha256=JujBVEUa_zZDXNPr1y-B_PhK5SgFFcY8Ib4EoGjjtmE,989
|
|
114
114
|
optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py,sha256=tTmsVaW9Wb2WD3nKRLwp7swn3hbMvgwUEJwwVIfNYEc,1008
|
|
@@ -130,7 +130,7 @@ optimum/rbln/transformers/models/gemma3/__init__.py,sha256=6rugk3615SEt4lh7gduo_
|
|
|
130
130
|
optimum/rbln/transformers/models/gemma3/configuration_gemma3.py,sha256=rKjKJhyaIM7YoiLR-q8GAZKIQNzDzcb5X7qf_FJE72M,3398
|
|
131
131
|
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py,sha256=fpLDAXCe5paWVsfc0tL59JkRQMRF-WNgIzOIb_QpSLU,6191
|
|
132
132
|
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py,sha256=vYQ9sjRlkfamxZca_hVMQI0ylKeExsV02gOWaYVMjyg,9640
|
|
133
|
-
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py,sha256=
|
|
133
|
+
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py,sha256=TxbgkvW2Nv0VGdXNXnN_Beas6E_1D9NAH8f09Fo8t0E,24239
|
|
134
134
|
optimum/rbln/transformers/models/gpt2/__init__.py,sha256=SsawHMStE3wYRtqkH5EvdTFkCdX0LLmp-QSKFhEBrHo,740
|
|
135
135
|
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py,sha256=iGdHfzG7plekZcIz-Z5U8lRE4SB8gbJJNcFQJ9l8Myg,1533
|
|
136
136
|
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py,sha256=MyAWReXmyuHnDpW5HI_TI7psyJZxLujZ9KT5XnNm7nA,2802
|
|
@@ -182,7 +182,7 @@ optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=VOboPJF1rvvSVWkH
|
|
|
182
182
|
optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
|
|
183
183
|
optimum/rbln/transformers/models/qwen2_5_vl/__init__.py,sha256=rAW3DKQUzGL6EMwa5r1iLu94yhpiZpk6zfoD7TtYXrc,865
|
|
184
184
|
optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py,sha256=1yyMFxh1SKsKR7rOjuotPvpSneN2_4a89bYfNk42370,4735
|
|
185
|
-
optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py,sha256=
|
|
185
|
+
optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py,sha256=hRvA37sPFC9xH1FqnFbtHS9rQOPwAvLYg4zl4oEyK-w,26639
|
|
186
186
|
optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py,sha256=i_UUWhKoFjJ5CCpgeWicqABM23TxMEKPQ354LoZ6iUU,7445
|
|
187
187
|
optimum/rbln/transformers/models/qwen3/__init__.py,sha256=tI4KwvXpD35dUUaa8aLUXpWoU9gJGcmKXeywOlH14ZE,746
|
|
188
188
|
optimum/rbln/transformers/models/qwen3/configuration_qwen3.py,sha256=BFRPggnH4VlsXlOa19C6KAID-bPgQ8ooQ29dvogh5zk,2102
|
|
@@ -227,7 +227,7 @@ optimum/rbln/transformers/models/xlm_roberta/__init__.py,sha256=O3o2KzJ8Li3QhB7G
|
|
|
227
227
|
optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py,sha256=wHRpGTXL9khYqSkKL1IgA7__6_lt9QpOz9tHumjK7fo,1260
|
|
228
228
|
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py,sha256=EZd3flRUEE38DYtdqEnG70LV7fHhkamRZV51xrVyjYI,1093
|
|
229
229
|
optimum/rbln/transformers/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
230
|
-
optimum/rbln/transformers/utils/rbln_quantization.py,sha256=
|
|
230
|
+
optimum/rbln/transformers/utils/rbln_quantization.py,sha256=pORshQUgTInNaibUtd0HL-T8bKW5wuulZs2q0Oshppc,21659
|
|
231
231
|
optimum/rbln/utils/__init__.py,sha256=ieDBT2VFTt2E0M4v_POLBpuGW9LxSydpb_DuPd6PQqc,712
|
|
232
232
|
optimum/rbln/utils/decorator_utils.py,sha256=xu-TrsNi33SRC2a7DBsyoo6-pEQxWKZPZSmM9QlDe2Y,3745
|
|
233
233
|
optimum/rbln/utils/depreacate_utils.py,sha256=uKxl3ENUCNaZXPnaDQvNxrH8hUIWdBWfZH6BM7ZV__4,385
|
|
@@ -238,7 +238,7 @@ optimum/rbln/utils/model_utils.py,sha256=4k5879Kh75m3x_vS4-qOGfqsOiAvc2kdNFFfvsF
|
|
|
238
238
|
optimum/rbln/utils/runtime_utils.py,sha256=R6uXDbeJP03-FWdd4vthNe2D4aCra5n12E3WB1ifiGM,7933
|
|
239
239
|
optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
|
|
240
240
|
optimum/rbln/utils/submodule.py,sha256=60NGLFvnhjP1DJg1opdb-FVQDsthcLCwWjW_1WQaasU,5280
|
|
241
|
-
optimum_rbln-0.8.
|
|
242
|
-
optimum_rbln-0.8.
|
|
243
|
-
optimum_rbln-0.8.
|
|
244
|
-
optimum_rbln-0.8.
|
|
241
|
+
optimum_rbln-0.8.4a0.dist-info/METADATA,sha256=QqrF_vPDFZO-DiTK0p328Y54qXyk1wApO86SAISpNcc,5299
|
|
242
|
+
optimum_rbln-0.8.4a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
243
|
+
optimum_rbln-0.8.4a0.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
244
|
+
optimum_rbln-0.8.4a0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|