optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/ops/flash_attn.py
CHANGED
@@ -12,71 +12,165 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from functools import lru_cache
|
16
|
-
|
17
15
|
import torch
|
18
|
-
from
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
torch.
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
16
|
+
from torch import Tensor
|
17
|
+
|
18
|
+
|
19
|
+
@torch.library.custom_op(
|
20
|
+
"rbln_custom_ops::paged_flash_attn_decode",
|
21
|
+
mutates_args=(["kcache", "vcache"]),
|
22
|
+
)
|
23
|
+
def paged_flash_attn_decode(
|
24
|
+
q: Tensor,
|
25
|
+
k: Tensor,
|
26
|
+
v: Tensor,
|
27
|
+
mask: Tensor,
|
28
|
+
kcache: Tensor,
|
29
|
+
vcache: Tensor,
|
30
|
+
seq: Tensor,
|
31
|
+
scale: Tensor,
|
32
|
+
block_table: Tensor,
|
33
|
+
block_size: int,
|
34
|
+
partition: int,
|
35
|
+
) -> Tensor:
|
36
|
+
"""Defines the computation pattern for fused flash attention with KV cache for decoding.
|
37
|
+
|
38
|
+
Returns a tensor with the same shape as q.
|
39
|
+
"""
|
40
|
+
return torch.empty_like(q)
|
41
|
+
|
42
|
+
|
43
|
+
@paged_flash_attn_decode.register_fake
|
44
|
+
def paged_flash_attn_decode_fake(
|
45
|
+
q: Tensor,
|
46
|
+
k: Tensor,
|
47
|
+
v: Tensor,
|
48
|
+
mask: Tensor,
|
49
|
+
kcache: Tensor,
|
50
|
+
vcache: Tensor,
|
51
|
+
seq: Tensor,
|
52
|
+
scale: Tensor,
|
53
|
+
block_table: Tensor,
|
54
|
+
block_size: int,
|
55
|
+
partition: int,
|
56
|
+
) -> Tensor:
|
57
|
+
return torch.empty_like(q)
|
58
|
+
|
59
|
+
|
60
|
+
@torch.library.custom_op(
|
61
|
+
"rbln_custom_ops::paged_flash_attn_prefill",
|
62
|
+
mutates_args=(["kcache", "vcache"]),
|
63
|
+
)
|
64
|
+
def paged_flash_attn_prefill(
|
65
|
+
q: Tensor,
|
66
|
+
k: Tensor,
|
67
|
+
v: Tensor,
|
68
|
+
mask: Tensor,
|
69
|
+
kcache: Tensor,
|
70
|
+
vcache: Tensor,
|
71
|
+
seq: Tensor,
|
72
|
+
scale: Tensor,
|
73
|
+
block_table: Tensor,
|
74
|
+
block_size: int,
|
75
|
+
partition: int,
|
76
|
+
) -> Tensor:
|
77
|
+
"""Defines the computation pattern for fused flash attention with KV cache for prefill.
|
78
|
+
|
79
|
+
Returns a tensor with the same shape as q.
|
80
|
+
"""
|
81
|
+
return torch.empty_like(q)
|
82
|
+
|
83
|
+
|
84
|
+
@paged_flash_attn_prefill.register_fake
|
85
|
+
def paged_flash_attn_prefill_fake(
|
86
|
+
q: Tensor,
|
87
|
+
k: Tensor,
|
88
|
+
v: Tensor,
|
89
|
+
mask: Tensor,
|
90
|
+
kcache: Tensor,
|
91
|
+
vcache: Tensor,
|
92
|
+
seq: Tensor,
|
93
|
+
scale: Tensor,
|
94
|
+
block_table: Tensor,
|
95
|
+
block_size: int,
|
96
|
+
partition: int,
|
97
|
+
) -> Tensor:
|
98
|
+
return torch.empty_like(q)
|
99
|
+
|
100
|
+
|
101
|
+
@torch.library.custom_op(
|
102
|
+
"rbln_custom_ops::paged_flash_causal_attn_decode",
|
103
|
+
mutates_args=(["kcache", "vcache"]),
|
104
|
+
)
|
105
|
+
def paged_flash_causal_attn_decode(
|
106
|
+
q: Tensor,
|
107
|
+
k: Tensor,
|
108
|
+
v: Tensor,
|
109
|
+
kcache: Tensor,
|
110
|
+
vcache: Tensor,
|
111
|
+
seq: Tensor,
|
112
|
+
scale: Tensor,
|
113
|
+
block_table: Tensor,
|
114
|
+
block_size: int,
|
115
|
+
partition: int,
|
116
|
+
) -> Tensor:
|
117
|
+
"""Defines the computation pattern for fused causal flash attention with KV cache for decoding.
|
118
|
+
|
119
|
+
Returns a tensor with the same shape as q.
|
120
|
+
"""
|
121
|
+
return torch.empty_like(q)
|
122
|
+
|
123
|
+
|
124
|
+
@paged_flash_causal_attn_decode.register_fake
|
125
|
+
def paged_flash_causal_attn_decode_fake(
|
126
|
+
q: Tensor,
|
127
|
+
k: Tensor,
|
128
|
+
v: Tensor,
|
129
|
+
kcache: Tensor,
|
130
|
+
vcache: Tensor,
|
131
|
+
seq: Tensor,
|
132
|
+
scale: Tensor,
|
133
|
+
block_table: Tensor,
|
134
|
+
block_size: int,
|
135
|
+
partition: int,
|
136
|
+
) -> Tensor:
|
137
|
+
return torch.empty_like(q)
|
138
|
+
|
139
|
+
|
140
|
+
@torch.library.custom_op(
|
141
|
+
"rbln_custom_ops::paged_flash_causal_attn_prefill",
|
142
|
+
mutates_args=(["kcache", "vcache"]),
|
143
|
+
)
|
144
|
+
def paged_flash_causal_attn_prefill(
|
145
|
+
q: Tensor,
|
146
|
+
k: Tensor,
|
147
|
+
v: Tensor,
|
148
|
+
kcache: Tensor,
|
149
|
+
vcache: Tensor,
|
150
|
+
seq: Tensor,
|
151
|
+
scale: Tensor,
|
152
|
+
block_table: Tensor,
|
153
|
+
block_size: int,
|
154
|
+
partition: int,
|
155
|
+
) -> Tensor:
|
156
|
+
"""Defines the computation pattern for fused causal flash attention with KV cache for prefill.
|
157
|
+
|
158
|
+
Returns a tensor with the same shape as q.
|
159
|
+
"""
|
160
|
+
return torch.empty_like(q)
|
161
|
+
|
162
|
+
|
163
|
+
@paged_flash_causal_attn_prefill.register_fake
|
164
|
+
def paged_flash_causal_attn_prefill_fake(
|
165
|
+
q: Tensor,
|
166
|
+
k: Tensor,
|
167
|
+
v: Tensor,
|
168
|
+
kcache: Tensor,
|
169
|
+
vcache: Tensor,
|
170
|
+
seq: Tensor,
|
171
|
+
scale: Tensor,
|
172
|
+
block_table: Tensor,
|
173
|
+
block_size: int,
|
174
|
+
partition: int,
|
175
|
+
) -> Tensor:
|
176
|
+
return torch.empty_like(q)
|
@@ -12,49 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from functools import lru_cache
|
16
|
-
|
17
15
|
import torch
|
18
|
-
from
|
19
|
-
|
16
|
+
from torch import Tensor
|
20
17
|
|
21
|
-
if is_torch_greater_or_equal_than_2_4:
|
22
|
-
register_fake = torch.library.register_fake
|
23
|
-
else:
|
24
|
-
register_fake = torch.library.impl_abstract
|
25
18
|
|
26
|
-
|
27
|
-
|
28
|
-
def register_rbln_custom_cache_update():
|
19
|
+
@torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
|
20
|
+
def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
|
29
21
|
# Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
|
30
22
|
# This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
|
31
23
|
# The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
|
32
|
-
torch.
|
33
|
-
|
34
|
-
# Implementation of the "rbln_cache_update" operation for the CPU.
|
35
|
-
@torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
|
36
|
-
def rbln_cache_update_cpu(cache, state, position, axis):
|
37
|
-
assert position.dim() == 0
|
38
|
-
assert axis.dim() == 0
|
39
|
-
|
40
|
-
# Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
|
41
|
-
s = position # Start index for the update, specified by the position.
|
42
|
-
e = (
|
43
|
-
position + state.shape[axis]
|
44
|
-
) # End index is determined by adding the size of the state along the given axis.
|
45
|
-
|
46
|
-
# Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
|
47
|
-
# This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
|
48
|
-
cache.slice_scatter(state, dim=axis, start=s, end=e)
|
49
|
-
|
50
|
-
# 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
|
51
|
-
return torch.empty([256])
|
52
|
-
|
53
|
-
# Register a "fake" implementation of the "rbln_cache_update" operation.
|
54
|
-
# This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
|
55
|
-
@register_fake("rbln_custom_ops::rbln_cache_update")
|
56
|
-
def rbln_cache_update_abstract(cache, state, position, axis):
|
57
|
-
# Return a tensor with the same shape as the input cache tensor.
|
58
|
-
# This is a placeholder for the abstract implementation and does not perform any actual computation.
|
59
|
-
# Like the actual implementation, the abstraction assumes in-place device-side updates.
|
60
|
-
return torch.empty([256])
|
24
|
+
return torch.empty_like(cache)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from torch import Tensor
|
19
|
+
|
20
|
+
|
21
|
+
@torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
|
22
|
+
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
|
23
|
+
output_shape = list(input.shape[:-1])
|
24
|
+
output_shape += [weight.shape[0]]
|
25
|
+
return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
|
@@ -18,7 +18,15 @@ from transformers.utils import _LazyModule
|
|
18
18
|
|
19
19
|
|
20
20
|
_import_structure = {
|
21
|
-
"
|
21
|
+
"configuration_alias": [
|
22
|
+
"RBLNASTForAudioClassificationConfig",
|
23
|
+
"RBLNDistilBertForQuestionAnsweringConfig",
|
24
|
+
"RBLNResNetForImageClassificationConfig",
|
25
|
+
"RBLNXLMRobertaForSequenceClassificationConfig",
|
26
|
+
"RBLNRobertaForSequenceClassificationConfig",
|
27
|
+
"RBLNRobertaForMaskedLMConfig",
|
28
|
+
"RBLNViTForImageClassificationConfig",
|
29
|
+
],
|
22
30
|
"models": [
|
23
31
|
"RBLNAutoModel",
|
24
32
|
"RBLNAutoModelForAudioClassification",
|
@@ -33,29 +41,66 @@ _import_structure = {
|
|
33
41
|
"RBLNAutoModelForSpeechSeq2Seq",
|
34
42
|
"RBLNAutoModelForVision2Seq",
|
35
43
|
"RBLNBartForConditionalGeneration",
|
44
|
+
"RBLNBartForConditionalGenerationConfig",
|
36
45
|
"RBLNBartModel",
|
37
|
-
"
|
46
|
+
"RBLNBartModelConfig",
|
38
47
|
"RBLNBertForMaskedLM",
|
48
|
+
"RBLNBertForMaskedLMConfig",
|
39
49
|
"RBLNBertForQuestionAnswering",
|
50
|
+
"RBLNBertForQuestionAnsweringConfig",
|
51
|
+
"RBLNBertModel",
|
52
|
+
"RBLNBertModelConfig",
|
40
53
|
"RBLNCLIPTextModel",
|
54
|
+
"RBLNCLIPTextModelConfig",
|
41
55
|
"RBLNCLIPTextModelWithProjection",
|
56
|
+
"RBLNCLIPTextModelWithProjectionConfig",
|
42
57
|
"RBLNCLIPVisionModel",
|
58
|
+
"RBLNCLIPVisionModelConfig",
|
43
59
|
"RBLNCLIPVisionModelWithProjection",
|
60
|
+
"RBLNCLIPVisionModelWithProjectionConfig",
|
61
|
+
"RBLNDecoderOnlyModelForCausalLM",
|
62
|
+
"RBLNDecoderOnlyModelForCausalLMConfig",
|
44
63
|
"RBLNDPTForDepthEstimation",
|
64
|
+
"RBLNDPTForDepthEstimationConfig",
|
45
65
|
"RBLNExaoneForCausalLM",
|
66
|
+
"RBLNExaoneForCausalLMConfig",
|
46
67
|
"RBLNGemmaForCausalLM",
|
68
|
+
"RBLNGemmaForCausalLMConfig",
|
47
69
|
"RBLNGPT2LMHeadModel",
|
48
|
-
"
|
49
|
-
"
|
50
|
-
"
|
70
|
+
"RBLNGPT2LMHeadModelConfig",
|
71
|
+
"RBLNIdefics3VisionTransformer",
|
72
|
+
"RBLNIdefics3ForConditionalGeneration",
|
73
|
+
"RBLNIdefics3ForConditionalGenerationConfig",
|
74
|
+
"RBLNIdefics3VisionTransformerConfig",
|
51
75
|
"RBLNLlamaForCausalLM",
|
76
|
+
"RBLNLlamaForCausalLMConfig",
|
77
|
+
"RBLNLlavaNextForConditionalGeneration",
|
78
|
+
"RBLNLlavaNextForConditionalGenerationConfig",
|
79
|
+
"RBLNMidmLMHeadModel",
|
80
|
+
"RBLNMidmLMHeadModelConfig",
|
81
|
+
"RBLNMistralForCausalLM",
|
82
|
+
"RBLNMistralForCausalLMConfig",
|
52
83
|
"RBLNPhiForCausalLM",
|
84
|
+
"RBLNPhiForCausalLMConfig",
|
85
|
+
"RBLNQwen2ForCausalLM",
|
86
|
+
"RBLNQwen2ForCausalLMConfig",
|
87
|
+
"RBLNQwen2_5_VisionTransformerPretrainedModel",
|
88
|
+
"RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
|
89
|
+
"RBLNQwen2_5_VLForConditionalGeneration",
|
90
|
+
"RBLNQwen2_5_VLForConditionalGenerationConfig",
|
53
91
|
"RBLNT5EncoderModel",
|
92
|
+
"RBLNT5EncoderModelConfig",
|
54
93
|
"RBLNT5ForConditionalGeneration",
|
94
|
+
"RBLNT5ForConditionalGenerationConfig",
|
95
|
+
"RBLNWav2Vec2ForCTC",
|
96
|
+
"RBLNWav2Vec2ForCTCConfig",
|
97
|
+
"RBLNWhisperForConditionalGeneration",
|
98
|
+
"RBLNWhisperForConditionalGenerationConfig",
|
99
|
+
"RBLNTimeSeriesTransformerForPrediction",
|
100
|
+
"RBLNTimeSeriesTransformerForPredictionConfig",
|
55
101
|
"RBLNLlavaNextForConditionalGeneration",
|
56
|
-
"RBLNMidmLMHeadModel",
|
57
102
|
"RBLNXLMRobertaModel",
|
58
|
-
"
|
103
|
+
"RBLNXLMRobertaModelConfig",
|
59
104
|
],
|
60
105
|
"modeling_alias": [
|
61
106
|
"RBLNASTForAudioClassification",
|
@@ -69,7 +114,15 @@ _import_structure = {
|
|
69
114
|
}
|
70
115
|
|
71
116
|
if TYPE_CHECKING:
|
72
|
-
from .
|
117
|
+
from .configuration_alias import (
|
118
|
+
RBLNASTForAudioClassificationConfig,
|
119
|
+
RBLNDistilBertForQuestionAnsweringConfig,
|
120
|
+
RBLNResNetForImageClassificationConfig,
|
121
|
+
RBLNRobertaForMaskedLMConfig,
|
122
|
+
RBLNRobertaForSequenceClassificationConfig,
|
123
|
+
RBLNViTForImageClassificationConfig,
|
124
|
+
RBLNXLMRobertaForSequenceClassificationConfig,
|
125
|
+
)
|
73
126
|
from .modeling_alias import (
|
74
127
|
RBLNASTForAudioClassification,
|
75
128
|
RBLNDistilBertForQuestionAnswering,
|
@@ -93,29 +146,65 @@ if TYPE_CHECKING:
|
|
93
146
|
RBLNAutoModelForSpeechSeq2Seq,
|
94
147
|
RBLNAutoModelForVision2Seq,
|
95
148
|
RBLNBartForConditionalGeneration,
|
149
|
+
RBLNBartForConditionalGenerationConfig,
|
96
150
|
RBLNBartModel,
|
151
|
+
RBLNBartModelConfig,
|
97
152
|
RBLNBertForMaskedLM,
|
153
|
+
RBLNBertForMaskedLMConfig,
|
98
154
|
RBLNBertForQuestionAnswering,
|
155
|
+
RBLNBertForQuestionAnsweringConfig,
|
99
156
|
RBLNBertModel,
|
157
|
+
RBLNBertModelConfig,
|
100
158
|
RBLNCLIPTextModel,
|
159
|
+
RBLNCLIPTextModelConfig,
|
101
160
|
RBLNCLIPTextModelWithProjection,
|
161
|
+
RBLNCLIPTextModelWithProjectionConfig,
|
102
162
|
RBLNCLIPVisionModel,
|
163
|
+
RBLNCLIPVisionModelConfig,
|
103
164
|
RBLNCLIPVisionModelWithProjection,
|
165
|
+
RBLNCLIPVisionModelWithProjectionConfig,
|
166
|
+
RBLNDecoderOnlyModelForCausalLM,
|
167
|
+
RBLNDecoderOnlyModelForCausalLMConfig,
|
104
168
|
RBLNDPTForDepthEstimation,
|
169
|
+
RBLNDPTForDepthEstimationConfig,
|
105
170
|
RBLNExaoneForCausalLM,
|
171
|
+
RBLNExaoneForCausalLMConfig,
|
106
172
|
RBLNGemmaForCausalLM,
|
173
|
+
RBLNGemmaForCausalLMConfig,
|
107
174
|
RBLNGPT2LMHeadModel,
|
175
|
+
RBLNGPT2LMHeadModelConfig,
|
176
|
+
RBLNIdefics3ForConditionalGeneration,
|
177
|
+
RBLNIdefics3ForConditionalGenerationConfig,
|
178
|
+
RBLNIdefics3VisionTransformer,
|
179
|
+
RBLNIdefics3VisionTransformerConfig,
|
108
180
|
RBLNLlamaForCausalLM,
|
181
|
+
RBLNLlamaForCausalLMConfig,
|
109
182
|
RBLNLlavaNextForConditionalGeneration,
|
183
|
+
RBLNLlavaNextForConditionalGenerationConfig,
|
110
184
|
RBLNMidmLMHeadModel,
|
185
|
+
RBLNMidmLMHeadModelConfig,
|
111
186
|
RBLNMistralForCausalLM,
|
187
|
+
RBLNMistralForCausalLMConfig,
|
112
188
|
RBLNPhiForCausalLM,
|
189
|
+
RBLNPhiForCausalLMConfig,
|
190
|
+
RBLNQwen2_5_VisionTransformerPretrainedModel,
|
191
|
+
RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
|
192
|
+
RBLNQwen2_5_VLForConditionalGeneration,
|
193
|
+
RBLNQwen2_5_VLForConditionalGenerationConfig,
|
113
194
|
RBLNQwen2ForCausalLM,
|
195
|
+
RBLNQwen2ForCausalLMConfig,
|
114
196
|
RBLNT5EncoderModel,
|
197
|
+
RBLNT5EncoderModelConfig,
|
115
198
|
RBLNT5ForConditionalGeneration,
|
199
|
+
RBLNT5ForConditionalGenerationConfig,
|
200
|
+
RBLNTimeSeriesTransformerForPrediction,
|
201
|
+
RBLNTimeSeriesTransformerForPredictionConfig,
|
116
202
|
RBLNWav2Vec2ForCTC,
|
203
|
+
RBLNWav2Vec2ForCTCConfig,
|
117
204
|
RBLNWhisperForConditionalGeneration,
|
205
|
+
RBLNWhisperForConditionalGenerationConfig,
|
118
206
|
RBLNXLMRobertaModel,
|
207
|
+
RBLNXLMRobertaModelConfig,
|
119
208
|
)
|
120
209
|
else:
|
121
210
|
import sys
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from .configuration_generic import (
|
16
|
+
RBLNModelForAudioClassificationConfig,
|
17
|
+
RBLNModelForImageClassificationConfig,
|
18
|
+
RBLNModelForMaskedLMConfig,
|
19
|
+
RBLNModelForQuestionAnsweringConfig,
|
20
|
+
RBLNModelForSequenceClassificationConfig,
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
class RBLNASTForAudioClassificationConfig(RBLNModelForAudioClassificationConfig):
|
25
|
+
pass
|
26
|
+
|
27
|
+
|
28
|
+
class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
|
29
|
+
pass
|
30
|
+
|
31
|
+
|
32
|
+
class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
|
33
|
+
pass
|
34
|
+
|
35
|
+
|
36
|
+
class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
|
37
|
+
pass
|
38
|
+
|
39
|
+
|
40
|
+
class RBLNRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
|
41
|
+
pass
|
42
|
+
|
43
|
+
|
44
|
+
class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
|
45
|
+
pass
|
46
|
+
|
47
|
+
|
48
|
+
class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
|
49
|
+
pass
|
@@ -0,0 +1,142 @@
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import List, Optional, Tuple, Union
|
16
|
+
|
17
|
+
from ..configuration_utils import RBLNModelConfig
|
18
|
+
|
19
|
+
|
20
|
+
class _RBLNTransformerEncoderConfig(RBLNModelConfig):
|
21
|
+
rbln_model_input_names: Optional[List[str]] = None
|
22
|
+
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
max_seq_len: Optional[int] = None,
|
26
|
+
batch_size: Optional[int] = None,
|
27
|
+
model_input_names: Optional[List[str]] = None,
|
28
|
+
**kwargs,
|
29
|
+
):
|
30
|
+
"""
|
31
|
+
Args:
|
32
|
+
max_seq_len (Optional[int]): Maximum sequence length supported by the model.
|
33
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
34
|
+
model_input_names (Optional[List[str]]): Names of the input tensors for the model.
|
35
|
+
Defaults to class-specific rbln_model_input_names if not provided.
|
36
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
37
|
+
|
38
|
+
Raises:
|
39
|
+
ValueError: If batch_size is not a positive integer.
|
40
|
+
"""
|
41
|
+
super().__init__(**kwargs)
|
42
|
+
self.max_seq_len = max_seq_len
|
43
|
+
self.batch_size = batch_size or 1
|
44
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
45
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
46
|
+
|
47
|
+
self.model_input_names = model_input_names or self.rbln_model_input_names
|
48
|
+
|
49
|
+
|
50
|
+
class _RBLNImageModelConfig(RBLNModelConfig):
|
51
|
+
def __init__(
|
52
|
+
self, image_size: Optional[Union[int, Tuple[int, int]]] = None, batch_size: Optional[int] = None, **kwargs
|
53
|
+
):
|
54
|
+
"""
|
55
|
+
Args:
|
56
|
+
image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
|
57
|
+
Can be an integer for square images or a tuple (height, width).
|
58
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
59
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
60
|
+
|
61
|
+
Raises:
|
62
|
+
ValueError: If batch_size is not a positive integer.
|
63
|
+
"""
|
64
|
+
super().__init__(**kwargs)
|
65
|
+
self.image_size = image_size
|
66
|
+
self.batch_size = batch_size or 1
|
67
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
68
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
69
|
+
|
70
|
+
@property
|
71
|
+
def image_width(self):
|
72
|
+
if isinstance(self.image_size, int):
|
73
|
+
return self.image_size
|
74
|
+
elif isinstance(self.image_size, (list, tuple)):
|
75
|
+
return self.image_size[1]
|
76
|
+
else:
|
77
|
+
return self.image_size["width"]
|
78
|
+
|
79
|
+
@property
|
80
|
+
def image_height(self):
|
81
|
+
if isinstance(self.image_size, int):
|
82
|
+
return self.image_size
|
83
|
+
elif isinstance(self.image_size, (list, tuple)):
|
84
|
+
return self.image_size[0]
|
85
|
+
else:
|
86
|
+
return self.image_size["height"]
|
87
|
+
|
88
|
+
|
89
|
+
class RBLNModelForQuestionAnsweringConfig(_RBLNTransformerEncoderConfig):
|
90
|
+
pass
|
91
|
+
|
92
|
+
|
93
|
+
class RBLNModelForSequenceClassificationConfig(_RBLNTransformerEncoderConfig):
|
94
|
+
pass
|
95
|
+
|
96
|
+
|
97
|
+
class RBLNModelForMaskedLMConfig(_RBLNTransformerEncoderConfig):
|
98
|
+
pass
|
99
|
+
|
100
|
+
|
101
|
+
class RBLNModelForTextEncodingConfig(_RBLNTransformerEncoderConfig):
|
102
|
+
pass
|
103
|
+
|
104
|
+
|
105
|
+
# FIXME : Appropriate name ?
|
106
|
+
class RBLNTransformerEncoderForFeatureExtractionConfig(_RBLNTransformerEncoderConfig):
|
107
|
+
pass
|
108
|
+
|
109
|
+
|
110
|
+
class RBLNModelForImageClassificationConfig(_RBLNImageModelConfig):
|
111
|
+
pass
|
112
|
+
|
113
|
+
|
114
|
+
class RBLNModelForDepthEstimationConfig(_RBLNImageModelConfig):
|
115
|
+
pass
|
116
|
+
|
117
|
+
|
118
|
+
class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
|
119
|
+
def __init__(
|
120
|
+
self,
|
121
|
+
batch_size: Optional[int] = None,
|
122
|
+
max_length: Optional[int] = None,
|
123
|
+
num_mel_bins: Optional[int] = None,
|
124
|
+
**kwargs,
|
125
|
+
):
|
126
|
+
"""
|
127
|
+
Args:
|
128
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
129
|
+
max_length (Optional[int]): Maximum length of the audio input in time dimension.
|
130
|
+
num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
|
131
|
+
**kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
132
|
+
|
133
|
+
Raises:
|
134
|
+
ValueError: If batch_size is not a positive integer.
|
135
|
+
"""
|
136
|
+
super().__init__(**kwargs)
|
137
|
+
self.batch_size = batch_size or 1
|
138
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
139
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
140
|
+
|
141
|
+
self.max_length = max_length
|
142
|
+
self.num_mel_bins = num_mel_bins
|