optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +4 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +48 -0
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/models/__init__.py +35 -14
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
- optimum/rbln/utils/depreacate_utils.py +16 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
17
|
+
|
|
18
|
+
from transformers import PegasusForConditionalGeneration, PreTrainedModel
|
|
19
|
+
|
|
20
|
+
from ....utils.logging import get_logger
|
|
21
|
+
from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
|
|
22
|
+
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
|
23
|
+
from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig
|
|
24
|
+
from .pegasus_architecture import PegasusWrapper
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
logger = get_logger()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from transformers import PreTrainedModel
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
    """
    PEGASUS encoder model compiled for RBLN devices, intended for feature extraction.

    All behavior comes from `RBLNTransformerEncoderForFeatureExtraction`; this
    subclass adds no attributes or overrides of its own.
    """
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """
    RBLN optimized PEGASUS model for conditional text generation tasks.

    This class provides hardware-accelerated inference for PEGASUS models
    on RBLN devices, supporting sequence-to-sequence generation tasks
    such as summarization, translation, and text generation.
    """

    support_causal_attn = True

    @classmethod
    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
        """
        Wrap `model` in a `PegasusWrapper`.

        Args:
            model: The PEGASUS model to wrap.
            rbln_config: Configuration supplying `enc_max_seq_len` and `use_attention_mask`.

        Returns:
            PegasusWrapper: Wrapper built from `model` and the two config values.
        """
        return PegasusWrapper(
            model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
        )

    def __getattr__(self, name: str) -> Any:
        """
        Resolve attributes missing on this instance from `PegasusForConditionalGeneration`.

        Python only calls `__getattr__` after normal lookup has failed. Callables whose
        signature has a `self` parameter are returned bound to this instance; any other
        value is returned unchanged.

        Raises:
            AttributeError: If `PegasusForConditionalGeneration` has no attribute `name`.
        """
        val = getattr(PegasusForConditionalGeneration, name)

        if callable(val) and "self" in inspect.signature(val).parameters:
            # Plain functions taken from the class are unbound; supply this instance as `self`.
            return lambda *pargs, **kwargs: val(self, *pargs, **kwargs)

        return val
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Tuple
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
from transformers.modeling_attn_mask_utils import (
|
|
20
|
+
_prepare_4d_attention_mask,
|
|
21
|
+
)
|
|
22
|
+
from transformers.utils import logging
|
|
23
|
+
|
|
24
|
+
from ..seq2seq.seq2seq_architecture import (
|
|
25
|
+
Seq2SeqCrossAttention,
|
|
26
|
+
Seq2SeqDecoder,
|
|
27
|
+
Seq2SeqDecoderLayer,
|
|
28
|
+
Seq2SeqDecoderWrapper,
|
|
29
|
+
Seq2SeqEncoderWrapper,
|
|
30
|
+
Seq2SeqForConditionalGeneration,
|
|
31
|
+
Seq2SeqSelfAttention,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
logger = logging.get_logger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class PegasusWrapper:
    """Holds the encoder and decoder wrappers built from a single PEGASUS model."""

    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
        """
        Args:
            model: Model handed to both the encoder and the decoder wrapper.
            enc_max_seq_len: Forwarded to `Seq2SeqEncoderWrapper`.
            use_attention_mask: Forwarded to `PegasusDecoderWrapper`.
        """
        encoder_wrapper = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
        decoder_wrapper = PegasusDecoderWrapper(model, use_attention_mask=use_attention_mask)
        self.encoder = encoder_wrapper
        self.decoder = decoder_wrapper
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class PegasusDecoderWrapper(Seq2SeqDecoderWrapper):
    """Decoder wrapper that rebuilds a PEGASUS decoder from the Pegasus* RBLN modules."""

    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        """
        Build the RBLN conditional-generation module for `model`.

        Each layer of `model.get_decoder()` is wrapped in a `PegasusDecoderLayer`
        together with its self-attention and cross-attention wrappers; the wrapped
        layers are then assembled into a `PegasusDecoder` and finally a
        `PegasusForConditionalGeneration`.
        """
        rbln_layers = [
            PegasusDecoderLayer(
                layer,
                PegasusSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask),
                PegasusCrossAttention(layer.encoder_attn),
            )
            for layer in model.get_decoder().layers
        ]

        rbln_decoder = PegasusDecoder(model.get_decoder(), rbln_layers)
        return PegasusForConditionalGeneration(model, rbln_decoder)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
    """PEGASUS-named subclass of `Seq2SeqForConditionalGeneration`; adds no behavior of its own."""
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class PegasusDecoder(Seq2SeqDecoder):
    """PEGASUS decoder adapter: position embeddings, optional embedding scale, and mask preparation."""

    has_pos_emb = True

    def __post_init__(self):
        """Copy the position table, optional embedding scale and optional final norm from the wrapped decoder."""
        wrapped = self._original_mod
        self.embed_positions = wrapped.embed_positions
        # Both attributes are optional on the wrapped decoder; absent ones become None.
        self.embed_scale = getattr(wrapped, "embed_scale", None)
        self.final_layer_norm = getattr(wrapped, "layer_norm", None)

    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
        """
        Return `(attention_mask, encoder_attention_mask)` prepared for the attention layers.

        A non-None `attention_mask` gets two singleton axes inserted after its first axis.
        `encoder_attention_mask` is always converted with `_prepare_4d_attention_mask`
        using float32 and `tgt_len=1`.
        """
        decoder_mask = None if attention_mask is None else attention_mask[:, None, None, :]
        cross_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
        return decoder_mask, cross_mask

    def apply_position_embedding(self, inputs_embeds, cache_position):
        """
        Add position embeddings to `inputs_embeds`, one batch entry at a time.

        For entry `i`, rows of `embed_positions.weight` are selected with
        `cache_position[i]` and added to `inputs_embeds[i]`; the per-entry results
        are stacked back along dim 0.
        """
        position_table = self.embed_positions.weight
        per_entry = [
            position_table[cache_position[batch_idx]] + inputs_embeds[batch_idx]
            for batch_idx in range(inputs_embeds.shape[0])
        ]
        return torch.stack(per_entry, dim=0)

    def get_embedding(self):
        """Return the token-embedding callable, multiplied by `embed_scale` when one is set."""
        if self.embed_scale is None:
            return self.embed_tokens

        def scaled_embedding(x):
            return self.embed_tokens(x) * self.embed_scale

        return scaled_embedding
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class PegasusLayerFF(nn.Module):
    """Feed-forward sub-block of a decoder layer: layer norm, fc1, activation, fc2, residual add."""

    def __init__(self, decoder_layer):
        """
        Args:
            decoder_layer: Object providing `fc1`, `fc2`, `activation_fn` and
                `final_layer_norm`; the same objects are reused, not copied.
        """
        super().__init__()
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2
        self.activation_fn = decoder_layer.activation_fn
        # The layer's `final_layer_norm` is applied before the feed-forward projections.
        self.layer_norm = decoder_layer.final_layer_norm

    def forward(self, hidden_states):
        """Return `hidden_states + fc2(activation_fn(fc1(layer_norm(hidden_states))))`."""
        normed = self.layer_norm(hidden_states)
        expanded = self.activation_fn(self.fc1(normed))
        projected = self.fc2(expanded)
        # Residual connection around the whole feed-forward path.
        return hidden_states + projected
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class PegasusDecoderLayer(Seq2SeqDecoderLayer):
    """Decoder layer adapter: layer norms run in the pre-attention hooks, post hooks are identity."""

    def __post_init__(self):
        """Take the norms and cross-attention module from the wrapped layer and build the feed-forward block."""
        wrapped = self._original_mod
        self.self_attn_layer_norm = wrapped.self_attn_layer_norm
        self.encoder_attn = wrapped.encoder_attn
        self.encoder_attn_layer_norm = wrapped.encoder_attn_layer_norm
        self.ff_layer = PegasusLayerFF(wrapped)

    def pre_self_attn_layer_norm(self, hidden_states):
        """Apply `self_attn_layer_norm` to `hidden_states`."""
        normed = self.self_attn_layer_norm(hidden_states)
        return normed

    def post_self_attn_layer_norm(self, hidden_states):
        """Identity: returns `hidden_states` unchanged."""
        return hidden_states

    def pre_cross_attn_layer_norm(self, hidden_states):
        """Apply `encoder_attn_layer_norm` to `hidden_states`."""
        normed = self.encoder_attn_layer_norm(hidden_states)
        return normed

    def post_cross_attn_layer_norm(self, hidden_states):
        """Identity: returns `hidden_states` unchanged."""
        return hidden_states
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class PegasusSelfAttention(Seq2SeqSelfAttention):
    """Self-attention adapter that reuses the wrapped module's projections and pre-scales the query."""

    def __post_init__(self, use_attention_mask: bool = True):
        """
        Args:
            use_attention_mask: Selects the decode op: `paged_attn_decode` when true,
                `paged_causal_attn_decode` otherwise.
        """
        attn = self._original_mod
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
        self.num_heads = attn.num_heads
        self.head_dim = attn.embed_dim // attn.num_heads
        self.scaling = self.head_dim**-0.5
        # Only the selected op is looked up on `torch.ops.rbln_custom_ops`.
        self.attn_decode = (
            torch.ops.rbln_custom_ops.paged_attn_decode
            if use_attention_mask
            else torch.ops.rbln_custom_ops.paged_causal_attn_decode
        )

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return `(query, key, value)` projections of `hidden_states`; the query is multiplied by `scaling`."""
        scaled_query = self.q_proj(hidden_states) * self.scaling
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)
        return scaled_query, key, value
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class PegasusCrossAttention(Seq2SeqCrossAttention):
    """Cross-attention adapter that reuses the wrapped module's projections and dimensions."""

    def __post_init__(self):
        """Copy the projection layers and head geometry from the wrapped attention module."""
        attn = self._original_mod
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
        self.num_heads = attn.num_heads
        # Per-head width, derived by integer division of the embedding width.
        self.head_dim = attn.embed_dim // attn.num_heads
        self.embed_dim = attn.embed_dim
|
|
@@ -12,5 +12,5 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .configuration_phi import RBLNPhiForCausalLMConfig
|
|
16
|
-
from .modeling_phi import RBLNPhiForCausalLM
|
|
15
|
+
from .configuration_phi import RBLNPhiForCausalLMConfig, RBLNPhiModelConfig
|
|
16
|
+
from .modeling_phi import RBLNPhiForCausalLM, RBLNPhiModel
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
|
|
15
|
+
from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
@@ -40,3 +40,11 @@ class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
|
|
|
40
40
|
)
|
|
41
41
|
```
|
|
42
42
|
"""
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RBLNPhiModelConfig(RBLNDecoderOnlyModelConfig):
    """
    Configuration for RBLN Phi models.

    Adds nothing to `RBLNDecoderOnlyModelConfig`; it exists so Phi has its own
    named configuration class.
    """
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
from ....utils import logging
|
|
16
|
-
from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
|
|
16
|
+
from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
|
|
17
17
|
from .phi_architecture import PhiWrapper
|
|
18
18
|
|
|
19
19
|
|
|
@@ -81,3 +81,12 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
|
|
|
81
81
|
"""
|
|
82
82
|
|
|
83
83
|
_decoder_wrapper_cls = PhiWrapper
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class RBLNPhiModel(RBLNDecoderOnlyModel):
    """
    Phi transformer without a language modeling head.

    Inherits from [`RBLNDecoderOnlyModel`]; see the superclass documentation for
    the generic methods the library implements for all its models.
    """

    # Wrapper class used for the Phi decoder.
    _decoder_wrapper_cls = PhiWrapper
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Optional, Tuple
|
|
15
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from transformers import PhiForCausalLM
|
|
@@ -27,7 +27,7 @@ from ..decoderonly.decoderonly_architecture import (
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
|
-
from transformers import PhiForCausalLM
|
|
30
|
+
from transformers import PhiForCausalLM, PhiModel
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
class PhiWrapper(DecoderOnlyWrapper):
|
|
@@ -40,11 +40,11 @@ class PhiWrapper(DecoderOnlyWrapper):
|
|
|
40
40
|
def get_rbln_model_class(self):
|
|
41
41
|
return PhiModel
|
|
42
42
|
|
|
43
|
-
def get_model_layer(self,
|
|
44
|
-
return
|
|
43
|
+
def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
|
|
44
|
+
return model.model if self.is_causal_lm else model
|
|
45
45
|
|
|
46
|
-
def get_decoder_layers(self,
|
|
47
|
-
return
|
|
46
|
+
def get_decoder_layers(self, model: Union["PhiForCausalLM", "PhiModel"]):
|
|
47
|
+
return model.model.layers if self.is_causal_lm else model.layers
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
class PhiAttention(DecoderOnlyAttention):
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_pixtral import RBLNPixtralVisionModelConfig
|
|
16
|
+
from .modeling_pixtral import RBLNPixtralVisionModel
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Dict, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNPixtralVisionModelConfig(RBLNModelConfig):
    """Configuration for the RBLN Pixtral vision model."""

    def __init__(
        self,
        max_image_size: Optional[Tuple[int, int]] = None,
        batch_size: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs: Dict[str, Any],
    ):
        """
        Args:
            max_image_size (Optional[Tuple[int, int]]): Maximum input image size as
                (max_height, max_width). Stored as given; no default is resolved here.
            batch_size (Optional[int]): The batch size for image processing.
                A falsy value (None or 0) falls back to 1.
            output_hidden_states (Optional[bool]): Stored as given; no default is
                resolved here.
            **kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        # `<= 0` matches the "positive integer" contract stated in the error message.
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.max_image_size = max_image_size
        self.output_hidden_states = output_hidden_states
|