optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_qwen3_moe import RBLNQwen3MoeForCausalLMConfig
|
|
16
|
+
from .modeling_qwen3_moe import RBLNQwen3MoeForCausalLM
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RBLNQwen3MoeForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    RBLN configuration for Qwen3 MoE causal language models.

    This subclass defines no members of its own; it is an alias of
    RBLNDecoderOnlyModelForCausalLMConfig that gives the Qwen3 MoE model a
    dedicated configuration type.

    Example:
    ```python
    from optimum.rbln import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig

    # Build the configuration object first
    config = RBLNQwen3MoeForCausalLMConfig(
        batch_size=1,
        max_seq_len=262144,
        tensor_parallel_size=4
    )

    # Then pass it to from_pretrained
    model = RBLNQwen3MoeForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        export=True,
        rbln_config=config
    )
    ```
    """
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
16
|
+
from .qwen3_moe_architecture import Qwen3MoeWrapper
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RBLNQwen3MoeForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The Qwen3 Moe is a Mixture-of-Experts (MoE) variant of Qwen3, available as a base model and an aligned chat model.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based Qwen3MoeForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers Qwen3MoeForCausalLM model into a RBLN transformer model by:

    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    **Configuration:**
    This model uses [`RBLNQwen3MoeForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNQwen3MoeForCausalLMConfig`] or a dictionary conforming to its structure.

    See the [`RBLNQwen3MoeForCausalLMConfig`] class for all available configuration options.

    Examples:
    ```python
    from optimum.rbln import RBLNQwen3MoeForCausalLM

    # Simple usage using rbln_* arguments
    # `max_seq_len` is automatically inferred from the model config
    model = RBLNQwen3MoeForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        export=True,
        rbln_batch_size=1,
        rbln_tensor_parallel_size=4,
    )

    # Using a config dictionary
    rbln_config = {
        "batch_size": 1,
        "max_seq_len": 262144,
        "tensor_parallel_size": 4,
    }
    model = RBLNQwen3MoeForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        export=True,
        rbln_config=rbln_config
    )

    # Using a RBLNQwen3MoeForCausalLMConfig instance (recommended for type checking)
    from optimum.rbln import RBLNQwen3MoeForCausalLMConfig

    config = RBLNQwen3MoeForCausalLMConfig(
        batch_size=1,
        max_seq_len=262144,
        tensor_parallel_size=4
    )
    model = RBLNQwen3MoeForCausalLM.from_pretrained(
        "Qwen/Qwen3-30B-A3B-Thinking-2507",
        export=True,
        rbln_config=config
    )
    ```
    """

    # Decoder wrapper class for this model.
    # NOTE(review): presumably consumed by RBLNDecoderOnlyModelForCausalLM when
    # wrapping the model for compilation - confirm against the base class.
    _decoder_wrapper_cls = Qwen3MoeWrapper
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
|
|
20
|
+
from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
|
|
21
|
+
from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class Qwen3MoeWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper that supplies the Qwen3 MoE attention and layer classes."""

    def get_rbln_attn_class(self):
        """Return the attention class used for Qwen3 MoE."""
        return Qwen3MoeAttention

    def get_rbln_layer_class(self):
        """Return the decoder-layer class used for Qwen3 MoE."""
        return Qwen3MoeLayer
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Qwen3MoeAttention(DecoderOnlyAttention):
    """Attention module that reuses the submodules of an existing attention layer."""

    def __post_init__(self, self_attn):
        """
        Bind the projections and the query/key norms of ``self_attn`` onto this module.

        Args:
            self_attn: Source attention module; must expose ``q_proj``, ``k_proj``,
                ``v_proj``, ``o_proj``, ``q_norm`` and ``k_norm``.
        """
        # Same attributes, same order as assigning them one by one.
        for attr_name in ("q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"):
            setattr(self, attr_name, getattr(self_attn, attr_name))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Qwen3MoeLayer(DecoderOnlyLayer):
    """Decoder layer whose MLP is swapped for the local sparse MoE block when applicable."""

    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
        """
        Args:
            layer: Source decoder layer; its ``mlp`` attribute is inspected.
            self_attn: Attention module forwarded to the parent constructor.
            lora_config: Optional LoRA configuration forwarded to the parent constructor.
        """
        super().__init__(layer, self_attn, lora_config)
        source_mlp = layer.mlp
        # Matched by class name, so the source class itself need not be imported here.
        if source_mlp.__class__.__name__ == "Qwen3MoeSparseMoeBlock":
            self.mlp = Qwen3MoeSparseMoeBlock(source_mlp)
        else:
            # Any other MLP is used unchanged.
            self.mlp = source_mlp

    def get_mlp(self) -> nn.Module:
        """Return the MLP chosen in ``__init__``."""
        return self.mlp
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Qwen3MoeSparseMoeBlock(nn.Module):
    """
    Sparse MoE block rebuilt from an existing MoE module.

    Keeps the source module's router (``gate``) and replaces its experts with a
    single ``Qwen3MoeMLP`` built from ``model.experts``.
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Source MoE module exposing ``num_experts``, ``top_k``,
                ``norm_topk_prob``, ``gate`` and ``experts``.
        """
        super().__init__()
        top_k, norm_topk_prob = model.top_k, model.norm_topk_prob
        self.num_experts = model.num_experts
        self.top_k = top_k
        self.norm_topk_prob = norm_topk_prob
        self.gate = model.gate
        self.experts = Qwen3MoeMLP(model.experts, top_k, norm_topk_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Route every token through the experts.

        Args:
            hidden_states: 3-D tensor unpacked as (batch_size, sequence_length, hidden_dim).

        Returns:
            Tensor reshaped back to (batch_size, sequence_length, hidden_dim).
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten batch and sequence so each row is one token.
        tokens = hidden_states.view(-1, hidden_dim)
        # router logits: (batch * sequence_length, n_experts)
        routed = self.experts(tokens, self.gate(tokens))
        return routed.reshape(batch_size, sequence_length, hidden_dim)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class Qwen3MoeMLP(nn.Module):
    """
    All experts of a sparse MoE block, held as three stacked weight tensors.

    The gate/up/down projection weights of every expert are stacked along a new
    leading dimension and stored as the ``weight`` of ``gate_proj``, ``up_proj``
    and ``down_proj``. The forward pass hands them to the
    ``rbln_custom_ops.custom_moe_glu`` custom op.
    """

    def __init__(self, expert_list, top_k, norm_topk_prob):
        """
        Args:
            expert_list: Non-empty indexable sequence of expert modules. Each must
                expose ``gate_proj``, ``up_proj`` and ``down_proj`` with a ``weight``;
                the first one also supplies ``hidden_size`` and ``intermediate_size``.
            top_k: Value forwarded to the custom op as its ``top_k`` argument.
            norm_topk_prob: Value forwarded to the custom op as its
                ``norm_topk_prob`` argument.
        """
        super().__init__()
        self.hidden_size = expert_list[0].hidden_size
        self.intermediate_size = expert_list[0].intermediate_size
        self.top_k = top_k
        self.norm_topk_prob = norm_topk_prob
        self.num_experts = len(expert_list)

        fused_size = self.num_experts * self.intermediate_size
        # The Linear weights are replaced right below, so skip their random
        # initialisation: running it over num_experts * intermediate_size * hidden_size
        # elements three times would be discarded work.
        self.gate_proj = nn.utils.skip_init(nn.Linear, self.hidden_size, fused_size, bias=False)
        self.up_proj = nn.utils.skip_init(nn.Linear, self.hidden_size, fused_size, bias=False)
        self.down_proj = nn.utils.skip_init(nn.Linear, fused_size, self.hidden_size, bias=False)

        # Stack per-expert weights along a new leading expert dimension.
        self.gate_proj.weight.data = torch.stack([expert.gate_proj.weight.data for expert in expert_list], dim=0)
        self.up_proj.weight.data = torch.stack([expert.up_proj.weight.data for expert in expert_list], dim=0)
        self.down_proj.weight.data = torch.stack([expert.down_proj.weight.data for expert in expert_list], dim=0)

    def forward(self, x, router_logits):
        """
        Run the fused MoE GLU custom op.

        Args:
            x: Input passed through to the custom op.
            router_logits: Router output passed through to the custom op.

        Returns:
            Whatever ``torch.ops.rbln_custom_ops.custom_moe_glu`` returns.
        """
        return torch.ops.rbln_custom_ops.custom_moe_glu(
            x,
            self.gate_proj.weight,
            self.up_proj.weight,
            self.down_proj.weight,
            router_logits,
            self.top_k,
            self.norm_topk_prob,
        )
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
|
|
16
|
-
from typing import Optional
|
|
16
|
+
from typing import Any, Optional, Tuple, Union
|
|
17
17
|
|
|
18
18
|
from ...configuration_generic import RBLNModelForImageClassificationConfig
|
|
19
19
|
|
|
@@ -26,17 +26,23 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
|
|
|
26
26
|
RBLN-optimized ResNet models for image classification tasks.
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
|
-
def __init__(
|
|
29
|
+
def __init__(
|
|
30
|
+
self,
|
|
31
|
+
image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
32
|
+
batch_size: Optional[int] = None,
|
|
33
|
+
output_hidden_states: Optional[bool] = None,
|
|
34
|
+
**kwargs: Any,
|
|
35
|
+
):
|
|
30
36
|
"""
|
|
31
37
|
Args:
|
|
32
38
|
image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
|
|
33
39
|
Can be an integer for square images or a tuple (height, width).
|
|
34
40
|
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
35
|
-
output_hidden_states (bool, optional)
|
|
41
|
+
output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
|
|
36
42
|
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
37
43
|
|
|
38
44
|
Raises:
|
|
39
45
|
ValueError: If batch_size is not a positive integer.
|
|
40
46
|
"""
|
|
41
|
-
super().__init__(**kwargs)
|
|
47
|
+
super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
|
|
42
48
|
self.output_hidden_states = output_hidden_states
|
|
@@ -268,13 +268,12 @@ class Seq2SeqDecoder(torch.nn.Module):
|
|
|
268
268
|
|
|
269
269
|
def __init__(self, model, layers, **kwargs):
|
|
270
270
|
super().__init__()
|
|
271
|
-
self._original_mod = model
|
|
272
271
|
self.layers = nn.ModuleList(layers)
|
|
273
272
|
self.embed_tokens = model.embed_tokens
|
|
274
|
-
self.final_layer_norm = getattr(model, "final_layer_norm", None)
|
|
275
|
-
self.__post_init__(**kwargs)
|
|
273
|
+
self.final_layer_norm = getattr(model, "final_layer_norm", None) or getattr(model, "layer_norm", None)
|
|
274
|
+
self.__post_init__(model, **kwargs)
|
|
276
275
|
|
|
277
|
-
def __post_init__(self, **kwargs):
|
|
276
|
+
def __post_init__(self, model: nn.Module, **kwargs):
|
|
278
277
|
"""
|
|
279
278
|
Abstract method intended to be overridden by subclasses to modify or override
|
|
280
279
|
the attributes of the original model after initialization.
|
|
@@ -344,12 +343,11 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
|
|
|
344
343
|
|
|
345
344
|
def __init__(self, decoder_layer, self_attn, cross_attn):
|
|
346
345
|
super().__init__()
|
|
347
|
-
self._original_mod = decoder_layer
|
|
348
346
|
self.self_attn = self_attn
|
|
349
347
|
self.cross_attn = cross_attn
|
|
350
|
-
self.__post_init__()
|
|
348
|
+
self.__post_init__(decoder_layer)
|
|
351
349
|
|
|
352
|
-
def __post_init__(self, **kwargs):
|
|
350
|
+
def __post_init__(self, decoder_layer: nn.Module, **kwargs):
|
|
353
351
|
"""
|
|
354
352
|
Abstract method intended to be overridden by subclasses to modify or override
|
|
355
353
|
the attributes of the original model after initialization.
|
|
@@ -423,10 +421,9 @@ class Seq2SeqDecoderLayer(torch.nn.Module):
|
|
|
423
421
|
class Seq2SeqSelfAttention(nn.Module):
|
|
424
422
|
def __init__(self, attn, **kwargs):
|
|
425
423
|
super().__init__()
|
|
426
|
-
self.
|
|
427
|
-
self.__post_init__(**kwargs)
|
|
424
|
+
self.__post_init__(attn, **kwargs)
|
|
428
425
|
|
|
429
|
-
def __post_init__(self, **kwargs):
|
|
426
|
+
def __post_init__(self, attn: nn.Module, **kwargs):
|
|
430
427
|
"""
|
|
431
428
|
Abstract method intended to be overridden by subclasses to modify or override
|
|
432
429
|
the attributes of the original model after initialization.
|
|
@@ -495,8 +492,13 @@ class Seq2SeqSelfAttention(nn.Module):
|
|
|
495
492
|
class Seq2SeqCrossAttention(nn.Module):
|
|
496
493
|
def __init__(self, attn, **kwargs):
|
|
497
494
|
super().__init__()
|
|
498
|
-
self.
|
|
499
|
-
|
|
495
|
+
self.__post_init__(attn, **kwargs)
|
|
496
|
+
|
|
497
|
+
def __post_init__(self, attn: nn.Module, **kwargs):
|
|
498
|
+
"""
|
|
499
|
+
Optional post-init hook for subclasses (e.g., to register q/k/v/out projections).
|
|
500
|
+
"""
|
|
501
|
+
pass
|
|
500
502
|
|
|
501
503
|
def forward(
|
|
502
504
|
self,
|
|
@@ -21,6 +21,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
|
|
|
21
21
|
from ....configuration_utils import RBLNCompileConfig
|
|
22
22
|
from ....modeling import RBLNModel
|
|
23
23
|
from ....utils.logging import get_logger
|
|
24
|
+
from ...modeling_outputs import _validate_output_attentions, _validate_output_hidden_states
|
|
24
25
|
from .configuration_siglip import RBLNSiglipVisionModelConfig
|
|
25
26
|
|
|
26
27
|
|
|
@@ -52,7 +53,7 @@ class _SiglipVisionModel(torch.nn.Module):
|
|
|
52
53
|
interpolate_pos_encoding=self.interpolate_pos_encoding,
|
|
53
54
|
output_attentions=self.output_attentions,
|
|
54
55
|
)
|
|
55
|
-
return
|
|
56
|
+
return enc_out
|
|
56
57
|
|
|
57
58
|
|
|
58
59
|
class RBLNSiglipVisionModel(RBLNModel):
|
|
@@ -138,23 +139,8 @@ class RBLNSiglipVisionModel(RBLNModel):
|
|
|
138
139
|
The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
|
|
139
140
|
"""
|
|
140
141
|
|
|
141
|
-
output_attentions = output_attentions
|
|
142
|
-
output_hidden_states = (
|
|
143
|
-
output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
|
|
144
|
-
)
|
|
145
|
-
|
|
146
|
-
if output_attentions != self.rbln_config.output_attentions:
|
|
147
|
-
raise ValueError(
|
|
148
|
-
f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
|
|
149
|
-
f"Please compile again with the correct argument."
|
|
150
|
-
)
|
|
151
|
-
|
|
152
|
-
if output_hidden_states != self.rbln_config.output_hidden_states:
|
|
153
|
-
raise ValueError(
|
|
154
|
-
f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
|
|
155
|
-
f"Please compile again with the correct argument."
|
|
156
|
-
)
|
|
157
|
-
|
|
142
|
+
output_attentions = _validate_output_attentions(output_attentions, self.rbln_config)
|
|
143
|
+
output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
|
|
158
144
|
if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
|
|
159
145
|
raise ValueError(
|
|
160
146
|
f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
|
|
@@ -32,11 +32,6 @@ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
|
|
|
32
32
|
Raises:
|
|
33
33
|
ValueError: If batch_size is not a positive integer.
|
|
34
34
|
"""
|
|
35
|
-
super().__init__(**kwargs)
|
|
36
|
-
self.batch_size = batch_size or 1
|
|
37
|
-
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
|
38
|
-
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
39
|
-
|
|
40
|
-
self.image_size = image_size
|
|
35
|
+
super().__init__(batch_size=batch_size, image_size=image_size, **kwargs)
|
|
41
36
|
self.output_hidden_states = output_hidden_states
|
|
42
37
|
self.output_attentions = output_attentions
|
|
@@ -111,9 +111,9 @@ class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
|
|
|
111
111
|
class T5Decoder(Seq2SeqDecoder):
|
|
112
112
|
has_pos_emb = False
|
|
113
113
|
|
|
114
|
-
def __post_init__(self, dec_max_seq_len: int = None):
|
|
115
|
-
self.invert_attention_mask =
|
|
116
|
-
self._dec_position_bias = self.precompute_dec_position_bias(
|
|
114
|
+
def __post_init__(self, model: nn.Module, dec_max_seq_len: int = None):
|
|
115
|
+
self.invert_attention_mask = model.invert_attention_mask
|
|
116
|
+
self._dec_position_bias = self.precompute_dec_position_bias(model, dec_max_seq_len)
|
|
117
117
|
|
|
118
118
|
def precompute_dec_position_bias(self, model, dec_max_length):
|
|
119
119
|
attn_layer = model.block[0].layer[0].SelfAttention
|
|
@@ -145,13 +145,12 @@ class T5Decoder(Seq2SeqDecoder):
|
|
|
145
145
|
class T5Block(Seq2SeqDecoderLayer):
|
|
146
146
|
def __init__(self, decoder_layer, self_attn):
|
|
147
147
|
super().__init__(decoder_layer, self_attn, cross_attn=None)
|
|
148
|
-
self.__post_init__()
|
|
149
148
|
|
|
150
|
-
def __post_init__(self):
|
|
151
|
-
self.self_attn_layer_norm =
|
|
152
|
-
self.encoder_attn_layer_norm =
|
|
153
|
-
self.cross_attn = T5CrossAttention(
|
|
154
|
-
self.ff_layer =
|
|
149
|
+
def __post_init__(self, decoder_layer: nn.Module):
|
|
150
|
+
self.self_attn_layer_norm = decoder_layer.layer[0].layer_norm
|
|
151
|
+
self.encoder_attn_layer_norm = decoder_layer.layer[1].layer_norm
|
|
152
|
+
self.cross_attn = T5CrossAttention(decoder_layer.layer[1].EncDecAttention)
|
|
153
|
+
self.ff_layer = decoder_layer.layer[2]
|
|
155
154
|
|
|
156
155
|
def pre_self_attn_layer_norm(self, hidden_states):
|
|
157
156
|
return self.self_attn_layer_norm(hidden_states)
|
|
@@ -167,13 +166,13 @@ class T5Block(Seq2SeqDecoderLayer):
|
|
|
167
166
|
|
|
168
167
|
|
|
169
168
|
class T5LayerSelfAttention(Seq2SeqSelfAttention):
|
|
170
|
-
def __post_init__(self):
|
|
171
|
-
self.q_proj =
|
|
172
|
-
self.k_proj =
|
|
173
|
-
self.v_proj =
|
|
174
|
-
self.out_proj =
|
|
175
|
-
self.num_heads =
|
|
176
|
-
self.head_dim =
|
|
169
|
+
def __post_init__(self, attn: nn.Module):
|
|
170
|
+
self.q_proj = attn.q
|
|
171
|
+
self.k_proj = attn.k
|
|
172
|
+
self.v_proj = attn.v
|
|
173
|
+
self.out_proj = attn.o
|
|
174
|
+
self.num_heads = attn.n_heads
|
|
175
|
+
self.head_dim = attn.key_value_proj_dim
|
|
177
176
|
self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
|
|
178
177
|
|
|
179
178
|
def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
CHANGED
|
@@ -140,7 +140,6 @@ class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
|
|
|
140
140
|
class TimeSeriesTransformersDecoder(nn.Module):
|
|
141
141
|
def __init__(self, model, layers, **kwargs):
|
|
142
142
|
super().__init__()
|
|
143
|
-
self._original_mod = model
|
|
144
143
|
self.config = model.config
|
|
145
144
|
self.layers = nn.ModuleList(layers)
|
|
146
145
|
self.value_embedding = model.value_embedding
|
|
@@ -190,7 +189,6 @@ class TimeSeriesTransformersDecoder(nn.Module):
|
|
|
190
189
|
class TimeSeriesTransformersDecoderLayer(nn.Module):
|
|
191
190
|
def __init__(self, decoder_layer, self_attn, cross_attn):
|
|
192
191
|
super().__init__()
|
|
193
|
-
self._original_mod = decoder_layer
|
|
194
192
|
self.self_attn = self_attn
|
|
195
193
|
self.encoder_attn = cross_attn
|
|
196
194
|
self.embed_dim = decoder_layer.embed_dim
|
|
@@ -245,7 +243,6 @@ class TimeSeriesTransformersDecoderLayer(nn.Module):
|
|
|
245
243
|
class TimeSeriesTransformersAttention(nn.Module):
|
|
246
244
|
def __init__(self, attn, num_parallel_samples):
|
|
247
245
|
super().__init__()
|
|
248
|
-
self._original_mod = attn
|
|
249
246
|
self.q_proj = attn.q_proj
|
|
250
247
|
self.k_proj = attn.k_proj
|
|
251
248
|
self.v_proj = attn.v_proj
|
|
@@ -51,22 +51,22 @@ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
|
|
|
51
51
|
return_segments: Optional[bool] = None,
|
|
52
52
|
return_timestamps: Optional[bool] = None,
|
|
53
53
|
return_token_timestamps: Optional[bool] = None,
|
|
54
|
-
**kwargs,
|
|
54
|
+
**kwargs: Optional[Dict[str, Any]],
|
|
55
55
|
) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
|
|
56
56
|
"""
|
|
57
57
|
The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
|
|
58
58
|
Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
|
|
59
59
|
|
|
60
60
|
Args:
|
|
61
|
-
input_features(torch.Tensor, optional): The input features to the model.
|
|
62
|
-
attention_mask(torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
|
|
63
|
-
generation_config(GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
|
|
61
|
+
input_features (torch.Tensor, optional): The input features to the model.
|
|
62
|
+
attention_mask (torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
|
|
63
|
+
generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
|
|
64
64
|
If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
|
|
65
65
|
Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
|
|
66
|
-
return_segments(bool, optional): Whether to return segments.
|
|
67
|
-
return_timestamps(bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
|
|
68
|
-
return_token_timestamps(bool, optional): Whether to return token timestamps.
|
|
69
|
-
kwargs(dict[str, Any], optional): Additional arguments passed to the generate function.
|
|
66
|
+
return_segments (bool, optional): Whether to return segments.
|
|
67
|
+
return_timestamps (bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
|
|
68
|
+
return_token_timestamps (bool, optional): Whether to return token timestamps.
|
|
69
|
+
kwargs (dict[str, Any], optional): Additional arguments passed to the generate function.
|
|
70
70
|
|
|
71
71
|
Returns:
|
|
72
72
|
Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
|
|
@@ -154,7 +154,6 @@ class WhisperDecoderWrapper(torch.nn.Module):
|
|
|
154
154
|
class WhisperDecoder(nn.Module):
|
|
155
155
|
def __init__(self, model, layers, **kwargs):
|
|
156
156
|
super().__init__()
|
|
157
|
-
self._original_mod = model
|
|
158
157
|
self.layers = nn.ModuleList(layers)
|
|
159
158
|
self.embed_tokens = model.embed_tokens
|
|
160
159
|
self.layer_norm = model.layer_norm
|
|
@@ -210,7 +209,6 @@ class WhisperDecoder(nn.Module):
|
|
|
210
209
|
class WhisperDecoderLayer(nn.Module):
|
|
211
210
|
def __init__(self, decoder_layer, self_attn, cross_attn):
|
|
212
211
|
super().__init__()
|
|
213
|
-
self._original_mod = decoder_layer
|
|
214
212
|
self.self_attn = self_attn
|
|
215
213
|
self.encoder_attn = cross_attn
|
|
216
214
|
self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
|
|
@@ -263,7 +261,6 @@ class WhisperDecoderLayer(nn.Module):
|
|
|
263
261
|
class WhisperAttention(nn.Module):
|
|
264
262
|
def __init__(self, attn):
|
|
265
263
|
super().__init__()
|
|
266
|
-
self._original_mod = attn
|
|
267
264
|
self.q_proj = attn.q_proj
|
|
268
265
|
self.k_proj = attn.k_proj
|
|
269
266
|
self.v_proj = attn.v_proj
|
|
@@ -221,11 +221,12 @@ def load_weight_files(
|
|
|
221
221
|
cache_dir: Optional[str] = None,
|
|
222
222
|
force_download: bool = False,
|
|
223
223
|
local_files_only: bool = False,
|
|
224
|
+
exception_keywords: Optional[List[str]] = None,
|
|
224
225
|
) -> list[str]:
|
|
225
226
|
"""
|
|
226
227
|
Discover and download safetensors files for the given model id.
|
|
227
228
|
"""
|
|
228
|
-
|
|
229
|
+
exception_keywords = exception_keywords or []
|
|
229
230
|
if os.path.isdir(model_id):
|
|
230
231
|
safetensor_files = glob.glob(f"{model_id}/*.safetensors")
|
|
231
232
|
else:
|
|
@@ -237,17 +238,24 @@ def load_weight_files(
|
|
|
237
238
|
|
|
238
239
|
for file in repo_files:
|
|
239
240
|
if file.endswith(".safetensors"):
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
241
|
+
exculde = False
|
|
242
|
+
for except_key in exception_keywords:
|
|
243
|
+
if except_key in file:
|
|
244
|
+
exculde = True
|
|
245
|
+
break
|
|
246
|
+
|
|
247
|
+
if not exculde:
|
|
248
|
+
# Download the safetensors file
|
|
249
|
+
downloaded_file = hf_hub_download(
|
|
250
|
+
repo_id=model_id,
|
|
251
|
+
filename=file,
|
|
252
|
+
revision=revision,
|
|
253
|
+
token=use_auth_token,
|
|
254
|
+
cache_dir=cache_dir,
|
|
255
|
+
force_download=force_download,
|
|
256
|
+
local_files_only=local_files_only,
|
|
257
|
+
)
|
|
258
|
+
safetensor_files.append(downloaded_file)
|
|
251
259
|
except Exception as e:
|
|
252
260
|
logger.error(f"Failed to download safetensors files from Hugging Face Hub: {e}")
|
|
253
261
|
raise e
|
|
@@ -194,7 +194,7 @@ def deprecate_kwarg(
|
|
|
194
194
|
message = f"{message} {additional_message}"
|
|
195
195
|
|
|
196
196
|
# update minimum_action if argument is ALREADY deprecated (current version >= deprecated version)
|
|
197
|
-
if is_greater_or_equal_version:
|
|
197
|
+
if is_greater_or_equal_version and message is not None:
|
|
198
198
|
# change to NOTIFY -> RAISE in case we want to raise error for already deprecated arguments
|
|
199
199
|
if raise_if_greater_or_equal_version:
|
|
200
200
|
minimum_action = Action.RAISE
|
|
@@ -211,3 +211,80 @@ def deprecate_kwarg(
|
|
|
211
211
|
return wrapped_func
|
|
212
212
|
|
|
213
213
|
return wrapper
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def deprecate_method(
    version: str,
    new_method: Optional[str] = None,
    raise_if_greater_or_equal_version: bool = True,
    additional_message: Optional[str] = None,
):
    """
    Decorator to mark a method as deprecated, optionally pointing to a replacement method.

    This decorator allows you to:
    - Notify users when a method is deprecated.
    - Optionally specify a new method name that should be used instead.
    - Raise an error when the deprecated method is called and the current version
      has already reached the specified version.

    The version comparison is done once, when the decorator is created. The warning
    or error itself is produced on every call of the decorated method.

    Parameters:
        version (`str`):
            The version in which the method was (or will be) deprecated.
        new_method (`Optional[str]`, *optional*):
            The name of the new method to use instead. If specified, users will be directed to use this method.
        raise_if_greater_or_equal_version (`bool`, *optional*, defaults to `True`):
            Whether to raise `ValueError` if current `optimum.rbln` version is greater than or equal to the deprecated version.
        additional_message (`Optional[str]`, *optional*):
            An additional message to append to the default deprecation message.

    Returns:
        Callable:
            A decorator. The function it produces handles the deprecation warning or error
            and then forwards the call to the original function.

    Raises:
        ValueError:
            Raised by the decorated method at call time when the current version is greater than
            or equal to `version` and `raise_if_greater_or_equal_version` is `True`.

    Examples:
        >>> class MyClass:
        ...     @deprecate_method(version="0.12.0", new_method="from_pretrained")
        ...     def load(self, path):
        ...         return self.from_pretrained(path)
    """

    deprecated_version = packaging.version.parse(version)
    current_version = packaging.version.parse(__version__)
    is_greater_or_equal_version = current_version >= deprecated_version

    if is_greater_or_equal_version:
        version_message = f"and removed starting from version {version}"
    else:
        version_message = f"and will be removed in version {version}"

    def wrapper(func):
        # The parameter names only hint at how `func` is bound; the actual receiver is
        # validated at call time before it is used to build the message.
        function_named_args = set(inspect.signature(func).parameters)
        is_instance_method = "self" in function_named_args
        is_class_method = "cls" in function_named_args

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # Prefix the method name with its owner class for a clearer message.
            # Fall back to the bare function name when the receiver is not available
            # positionally (e.g. passed by keyword) or is not a class on the `cls` path,
            # so that the deprecation notice is never masked by an unrelated
            # IndexError / AttributeError.
            method_name = func.__name__
            owner_name = None
            if args:
                if is_instance_method:
                    owner_name = args[0].__class__.__name__
                elif is_class_method:
                    owner_name = getattr(args[0], "__name__", None)
            if owner_name is not None:
                method_name = f"{owner_name}.{method_name}"

            # Build deprecation message
            if new_method is not None:
                message = f"`{method_name}` is deprecated {version_message}. Use `{new_method}` instead."
            else:
                message = f"`{method_name}` is deprecated {version_message}."

            if additional_message is not None:
                message = f"{message} {additional_message}"

            # Already past the deprecation version: fail hard if requested, otherwise only warn.
            if is_greater_or_equal_version and raise_if_greater_or_equal_version:
                raise ValueError(message)

            # stacklevel=2 attributes the log record to the caller of the deprecated method.
            logger.warning(message, stacklevel=2)

            return func(*args, **kwargs)

        return wrapped_func

    return wrapper
|