optimum-rbln 0.7.3.post1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +11 -86
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -118
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +23 -151
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post1.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post1.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
@@ -12,71 +12,165 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from functools import lru_cache
16
-
17
15
  import torch
18
- from packaging import version
19
-
20
-
21
- if version.parse(torch.__version__) > version.parse("2.4.0"):
22
- register_fake = torch.library.register_fake
23
- else:
24
- register_fake = torch.library.impl_abstract
25
-
26
-
27
- @lru_cache
28
- def register_rbln_custom_paged_flash_attention():
29
- torch.library.define(
30
- "rbln_custom_ops::paged_flash_attn_decode",
31
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
32
- )
33
-
34
- @torch.library.impl("rbln_custom_ops::paged_flash_attn_decode", "cpu")
35
- def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
36
- return q
37
-
38
- @register_fake("rbln_custom_ops::paged_flash_attn_decode")
39
- def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
40
- return q
41
-
42
- torch.library.define(
43
- "rbln_custom_ops::paged_flash_attn_prefill",
44
- "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
45
- )
46
-
47
- @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
48
- def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, seq, scale, block_table, block_size, partition):
49
- return q
50
-
51
- @register_fake("rbln_custom_ops::paged_flash_attn_prefill")
52
- def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, seq, scale, block_table, block_size, partition):
53
- return q
54
-
55
-
56
- @lru_cache
57
- def register_rbln_custom_paged_flash_causal_attention():
58
- torch.library.define(
59
- "rbln_custom_ops::paged_flash_causal_attn_decode",
60
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
61
- )
62
-
63
- @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_decode", "cpu")
64
- def flash_attn_decode_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
65
- return q
66
-
67
- @register_fake("rbln_custom_ops::paged_flash_causal_attn_decode")
68
- def flash_attn_decode_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
69
- return q
70
-
71
- torch.library.define(
72
- "rbln_custom_ops::paged_flash_causal_attn_prefill",
73
- "(Tensor x, Tensor y, Tensor z, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f, int g) -> Tensor",
74
- )
75
-
76
- @torch.library.impl("rbln_custom_ops::paged_flash_causal_attn_prefill", "cpu")
77
- def flash_attn_prefill_cpu(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
78
- return q
79
-
80
- @register_fake("rbln_custom_ops::paged_flash_causal_attn_prefill")
81
- def flash_attn_prefill_abstract(q, k, v, kcache, vcache, seq, scale, block_table, block_size, partition):
82
- return q
16
+ from torch import Tensor
17
+
18
+
19
+ @torch.library.custom_op(
20
+ "rbln_custom_ops::paged_flash_attn_decode",
21
+ mutates_args=(["kcache", "vcache"]),
22
+ )
23
+ def paged_flash_attn_decode(
24
+ q: Tensor,
25
+ k: Tensor,
26
+ v: Tensor,
27
+ mask: Tensor,
28
+ kcache: Tensor,
29
+ vcache: Tensor,
30
+ seq: Tensor,
31
+ scale: Tensor,
32
+ block_table: Tensor,
33
+ block_size: int,
34
+ partition: int,
35
+ ) -> Tensor:
36
+ """Defines the computation pattern for fused flash attention with KV cache for decoding.
37
+
38
+ Returns a tensor with the same shape as q.
39
+ """
40
+ return torch.empty_like(q)
41
+
42
+
43
+ @paged_flash_attn_decode.register_fake
44
+ def paged_flash_attn_decode_fake(
45
+ q: Tensor,
46
+ k: Tensor,
47
+ v: Tensor,
48
+ mask: Tensor,
49
+ kcache: Tensor,
50
+ vcache: Tensor,
51
+ seq: Tensor,
52
+ scale: Tensor,
53
+ block_table: Tensor,
54
+ block_size: int,
55
+ partition: int,
56
+ ) -> Tensor:
57
+ return torch.empty_like(q)
58
+
59
+
60
+ @torch.library.custom_op(
61
+ "rbln_custom_ops::paged_flash_attn_prefill",
62
+ mutates_args=(["kcache", "vcache"]),
63
+ )
64
+ def paged_flash_attn_prefill(
65
+ q: Tensor,
66
+ k: Tensor,
67
+ v: Tensor,
68
+ mask: Tensor,
69
+ kcache: Tensor,
70
+ vcache: Tensor,
71
+ seq: Tensor,
72
+ scale: Tensor,
73
+ block_table: Tensor,
74
+ block_size: int,
75
+ partition: int,
76
+ ) -> Tensor:
77
+ """Defines the computation pattern for fused flash attention with KV cache for prefill.
78
+
79
+ Returns a tensor with the same shape as q.
80
+ """
81
+ return torch.empty_like(q)
82
+
83
+
84
+ @paged_flash_attn_prefill.register_fake
85
+ def paged_flash_attn_prefill_fake(
86
+ q: Tensor,
87
+ k: Tensor,
88
+ v: Tensor,
89
+ mask: Tensor,
90
+ kcache: Tensor,
91
+ vcache: Tensor,
92
+ seq: Tensor,
93
+ scale: Tensor,
94
+ block_table: Tensor,
95
+ block_size: int,
96
+ partition: int,
97
+ ) -> Tensor:
98
+ return torch.empty_like(q)
99
+
100
+
101
+ @torch.library.custom_op(
102
+ "rbln_custom_ops::paged_flash_causal_attn_decode",
103
+ mutates_args=(["kcache", "vcache"]),
104
+ )
105
+ def paged_flash_causal_attn_decode(
106
+ q: Tensor,
107
+ k: Tensor,
108
+ v: Tensor,
109
+ kcache: Tensor,
110
+ vcache: Tensor,
111
+ seq: Tensor,
112
+ scale: Tensor,
113
+ block_table: Tensor,
114
+ block_size: int,
115
+ partition: int,
116
+ ) -> Tensor:
117
+ """Defines the computation pattern for fused causal flash attention with KV cache for decoding.
118
+
119
+ Returns a tensor with the same shape as q.
120
+ """
121
+ return torch.empty_like(q)
122
+
123
+
124
+ @paged_flash_causal_attn_decode.register_fake
125
+ def paged_flash_causal_attn_decode_fake(
126
+ q: Tensor,
127
+ k: Tensor,
128
+ v: Tensor,
129
+ kcache: Tensor,
130
+ vcache: Tensor,
131
+ seq: Tensor,
132
+ scale: Tensor,
133
+ block_table: Tensor,
134
+ block_size: int,
135
+ partition: int,
136
+ ) -> Tensor:
137
+ return torch.empty_like(q)
138
+
139
+
140
+ @torch.library.custom_op(
141
+ "rbln_custom_ops::paged_flash_causal_attn_prefill",
142
+ mutates_args=(["kcache", "vcache"]),
143
+ )
144
+ def paged_flash_causal_attn_prefill(
145
+ q: Tensor,
146
+ k: Tensor,
147
+ v: Tensor,
148
+ kcache: Tensor,
149
+ vcache: Tensor,
150
+ seq: Tensor,
151
+ scale: Tensor,
152
+ block_table: Tensor,
153
+ block_size: int,
154
+ partition: int,
155
+ ) -> Tensor:
156
+ """Defines the computation pattern for fused causal flash attention with KV cache for prefill.
157
+
158
+ Returns a tensor with the same shape as q.
159
+ """
160
+ return torch.empty_like(q)
161
+
162
+
163
+ @paged_flash_causal_attn_prefill.register_fake
164
+ def paged_flash_causal_attn_prefill_fake(
165
+ q: Tensor,
166
+ k: Tensor,
167
+ v: Tensor,
168
+ kcache: Tensor,
169
+ vcache: Tensor,
170
+ seq: Tensor,
171
+ scale: Tensor,
172
+ block_table: Tensor,
173
+ block_size: int,
174
+ partition: int,
175
+ ) -> Tensor:
176
+ return torch.empty_like(q)
@@ -12,49 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from functools import lru_cache
16
-
17
15
  import torch
18
- from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
19
-
16
+ from torch import Tensor
20
17
 
21
- if is_torch_greater_or_equal_than_2_4:
22
- register_fake = torch.library.register_fake
23
- else:
24
- register_fake = torch.library.impl_abstract
25
18
 
26
-
27
- @lru_cache
28
- def register_rbln_custom_cache_update():
19
+ @torch.library.custom_op("rbln_custom_ops::rbln_cache_update", mutates_args=(["cache"]))
20
+ def rbln_cache_update(cache: Tensor, state: Tensor, position: Tensor, axis: Tensor) -> Tensor:
29
21
  # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
30
22
  # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
31
23
  # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
32
- torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
33
-
34
- # Implementation of the "rbln_cache_update" operation for the CPU.
35
- @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
36
- def rbln_cache_update_cpu(cache, state, position, axis):
37
- assert position.dim() == 0
38
- assert axis.dim() == 0
39
-
40
- # Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
41
- s = position # Start index for the update, specified by the position.
42
- e = (
43
- position + state.shape[axis]
44
- ) # End index is determined by adding the size of the state along the given axis.
45
-
46
- # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
47
- # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
48
- cache.slice_scatter(state, dim=axis, start=s, end=e)
49
-
50
- # 'rbln_cache_update' is an in-place operation that isn't tracked in JIT trace, so a dummy output was added to the return value.
51
- return torch.empty([256])
52
-
53
- # Register a "fake" implementation of the "rbln_cache_update" operation.
54
- # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
55
- @register_fake("rbln_custom_ops::rbln_cache_update")
56
- def rbln_cache_update_abstract(cache, state, position, axis):
57
- # Return a tensor with the same shape as the input cache tensor.
58
- # This is a placeholder for the abstract implementation and does not perform any actual computation.
59
- # Like the actual implementation, the abstraction assumes in-place device-side updates.
60
- return torch.empty([256])
24
+ return torch.empty_like(cache)
@@ -0,0 +1,25 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ import torch
18
+ from torch import Tensor
19
+
20
+
21
+ @torch.library.custom_op("rbln_custom_ops::linear", mutates_args=())
22
+ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
23
+ output_shape = list(input.shape[:-1])
24
+ output_shape += [weight.shape[0]]
25
+ return torch.empty(size=output_shape, dtype=input.dtype, device=input.device, requires_grad=input.requires_grad)
@@ -18,7 +18,15 @@ from transformers.utils import _LazyModule
18
18
 
19
19
 
20
20
  _import_structure = {
21
- "cache_utils": ["RebelDynamicCache"],
21
+ "configuration_alias": [
22
+ "RBLNASTForAudioClassificationConfig",
23
+ "RBLNDistilBertForQuestionAnsweringConfig",
24
+ "RBLNResNetForImageClassificationConfig",
25
+ "RBLNXLMRobertaForSequenceClassificationConfig",
26
+ "RBLNRobertaForSequenceClassificationConfig",
27
+ "RBLNRobertaForMaskedLMConfig",
28
+ "RBLNViTForImageClassificationConfig",
29
+ ],
22
30
  "models": [
23
31
  "RBLNAutoModel",
24
32
  "RBLNAutoModelForAudioClassification",
@@ -33,29 +41,66 @@ _import_structure = {
33
41
  "RBLNAutoModelForSpeechSeq2Seq",
34
42
  "RBLNAutoModelForVision2Seq",
35
43
  "RBLNBartForConditionalGeneration",
44
+ "RBLNBartForConditionalGenerationConfig",
36
45
  "RBLNBartModel",
37
- "RBLNBertModel",
46
+ "RBLNBartModelConfig",
38
47
  "RBLNBertForMaskedLM",
48
+ "RBLNBertForMaskedLMConfig",
39
49
  "RBLNBertForQuestionAnswering",
50
+ "RBLNBertForQuestionAnsweringConfig",
51
+ "RBLNBertModel",
52
+ "RBLNBertModelConfig",
40
53
  "RBLNCLIPTextModel",
54
+ "RBLNCLIPTextModelConfig",
41
55
  "RBLNCLIPTextModelWithProjection",
56
+ "RBLNCLIPTextModelWithProjectionConfig",
42
57
  "RBLNCLIPVisionModel",
58
+ "RBLNCLIPVisionModelConfig",
43
59
  "RBLNCLIPVisionModelWithProjection",
60
+ "RBLNCLIPVisionModelWithProjectionConfig",
61
+ "RBLNDecoderOnlyModelForCausalLM",
62
+ "RBLNDecoderOnlyModelForCausalLMConfig",
44
63
  "RBLNDPTForDepthEstimation",
64
+ "RBLNDPTForDepthEstimationConfig",
45
65
  "RBLNExaoneForCausalLM",
66
+ "RBLNExaoneForCausalLMConfig",
46
67
  "RBLNGemmaForCausalLM",
68
+ "RBLNGemmaForCausalLMConfig",
47
69
  "RBLNGPT2LMHeadModel",
48
- "RBLNQwen2ForCausalLM",
49
- "RBLNWav2Vec2ForCTC",
50
- "RBLNWhisperForConditionalGeneration",
70
+ "RBLNGPT2LMHeadModelConfig",
71
+ "RBLNIdefics3VisionTransformer",
72
+ "RBLNIdefics3ForConditionalGeneration",
73
+ "RBLNIdefics3ForConditionalGenerationConfig",
74
+ "RBLNIdefics3VisionTransformerConfig",
51
75
  "RBLNLlamaForCausalLM",
76
+ "RBLNLlamaForCausalLMConfig",
77
+ "RBLNLlavaNextForConditionalGeneration",
78
+ "RBLNLlavaNextForConditionalGenerationConfig",
79
+ "RBLNMidmLMHeadModel",
80
+ "RBLNMidmLMHeadModelConfig",
81
+ "RBLNMistralForCausalLM",
82
+ "RBLNMistralForCausalLMConfig",
52
83
  "RBLNPhiForCausalLM",
84
+ "RBLNPhiForCausalLMConfig",
85
+ "RBLNQwen2ForCausalLM",
86
+ "RBLNQwen2ForCausalLMConfig",
87
+ "RBLNQwen2_5_VisionTransformerPretrainedModel",
88
+ "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
89
+ "RBLNQwen2_5_VLForConditionalGeneration",
90
+ "RBLNQwen2_5_VLForConditionalGenerationConfig",
53
91
  "RBLNT5EncoderModel",
92
+ "RBLNT5EncoderModelConfig",
54
93
  "RBLNT5ForConditionalGeneration",
94
+ "RBLNT5ForConditionalGenerationConfig",
95
+ "RBLNWav2Vec2ForCTC",
96
+ "RBLNWav2Vec2ForCTCConfig",
97
+ "RBLNWhisperForConditionalGeneration",
98
+ "RBLNWhisperForConditionalGenerationConfig",
99
+ "RBLNTimeSeriesTransformerForPrediction",
100
+ "RBLNTimeSeriesTransformerForPredictionConfig",
55
101
  "RBLNLlavaNextForConditionalGeneration",
56
- "RBLNMidmLMHeadModel",
57
102
  "RBLNXLMRobertaModel",
58
- "RBLNMistralForCausalLM",
103
+ "RBLNXLMRobertaModelConfig",
59
104
  ],
60
105
  "modeling_alias": [
61
106
  "RBLNASTForAudioClassification",
@@ -69,7 +114,15 @@ _import_structure = {
69
114
  }
70
115
 
71
116
  if TYPE_CHECKING:
72
- from .cache_utils import RebelDynamicCache
117
+ from .configuration_alias import (
118
+ RBLNASTForAudioClassificationConfig,
119
+ RBLNDistilBertForQuestionAnsweringConfig,
120
+ RBLNResNetForImageClassificationConfig,
121
+ RBLNRobertaForMaskedLMConfig,
122
+ RBLNRobertaForSequenceClassificationConfig,
123
+ RBLNViTForImageClassificationConfig,
124
+ RBLNXLMRobertaForSequenceClassificationConfig,
125
+ )
73
126
  from .modeling_alias import (
74
127
  RBLNASTForAudioClassification,
75
128
  RBLNDistilBertForQuestionAnswering,
@@ -93,29 +146,65 @@ if TYPE_CHECKING:
93
146
  RBLNAutoModelForSpeechSeq2Seq,
94
147
  RBLNAutoModelForVision2Seq,
95
148
  RBLNBartForConditionalGeneration,
149
+ RBLNBartForConditionalGenerationConfig,
96
150
  RBLNBartModel,
151
+ RBLNBartModelConfig,
97
152
  RBLNBertForMaskedLM,
153
+ RBLNBertForMaskedLMConfig,
98
154
  RBLNBertForQuestionAnswering,
155
+ RBLNBertForQuestionAnsweringConfig,
99
156
  RBLNBertModel,
157
+ RBLNBertModelConfig,
100
158
  RBLNCLIPTextModel,
159
+ RBLNCLIPTextModelConfig,
101
160
  RBLNCLIPTextModelWithProjection,
161
+ RBLNCLIPTextModelWithProjectionConfig,
102
162
  RBLNCLIPVisionModel,
163
+ RBLNCLIPVisionModelConfig,
103
164
  RBLNCLIPVisionModelWithProjection,
165
+ RBLNCLIPVisionModelWithProjectionConfig,
166
+ RBLNDecoderOnlyModelForCausalLM,
167
+ RBLNDecoderOnlyModelForCausalLMConfig,
104
168
  RBLNDPTForDepthEstimation,
169
+ RBLNDPTForDepthEstimationConfig,
105
170
  RBLNExaoneForCausalLM,
171
+ RBLNExaoneForCausalLMConfig,
106
172
  RBLNGemmaForCausalLM,
173
+ RBLNGemmaForCausalLMConfig,
107
174
  RBLNGPT2LMHeadModel,
175
+ RBLNGPT2LMHeadModelConfig,
176
+ RBLNIdefics3ForConditionalGeneration,
177
+ RBLNIdefics3ForConditionalGenerationConfig,
178
+ RBLNIdefics3VisionTransformer,
179
+ RBLNIdefics3VisionTransformerConfig,
108
180
  RBLNLlamaForCausalLM,
181
+ RBLNLlamaForCausalLMConfig,
109
182
  RBLNLlavaNextForConditionalGeneration,
183
+ RBLNLlavaNextForConditionalGenerationConfig,
110
184
  RBLNMidmLMHeadModel,
185
+ RBLNMidmLMHeadModelConfig,
111
186
  RBLNMistralForCausalLM,
187
+ RBLNMistralForCausalLMConfig,
112
188
  RBLNPhiForCausalLM,
189
+ RBLNPhiForCausalLMConfig,
190
+ RBLNQwen2_5_VisionTransformerPretrainedModel,
191
+ RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
192
+ RBLNQwen2_5_VLForConditionalGeneration,
193
+ RBLNQwen2_5_VLForConditionalGenerationConfig,
113
194
  RBLNQwen2ForCausalLM,
195
+ RBLNQwen2ForCausalLMConfig,
114
196
  RBLNT5EncoderModel,
197
+ RBLNT5EncoderModelConfig,
115
198
  RBLNT5ForConditionalGeneration,
199
+ RBLNT5ForConditionalGenerationConfig,
200
+ RBLNTimeSeriesTransformerForPrediction,
201
+ RBLNTimeSeriesTransformerForPredictionConfig,
116
202
  RBLNWav2Vec2ForCTC,
203
+ RBLNWav2Vec2ForCTCConfig,
117
204
  RBLNWhisperForConditionalGeneration,
205
+ RBLNWhisperForConditionalGenerationConfig,
118
206
  RBLNXLMRobertaModel,
207
+ RBLNXLMRobertaModelConfig,
119
208
  )
120
209
  else:
121
210
  import sys
@@ -0,0 +1,49 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_generic import (
16
+ RBLNModelForAudioClassificationConfig,
17
+ RBLNModelForImageClassificationConfig,
18
+ RBLNModelForMaskedLMConfig,
19
+ RBLNModelForQuestionAnsweringConfig,
20
+ RBLNModelForSequenceClassificationConfig,
21
+ )
22
+
23
+
24
+ class RBLNASTForAudioClassificationConfig(RBLNModelForAudioClassificationConfig):
25
+ pass
26
+
27
+
28
+ class RBLNDistilBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
29
+ pass
30
+
31
+
32
+ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConfig):
33
+ pass
34
+
35
+
36
+ class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
37
+ pass
38
+
39
+
40
+ class RBLNRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
41
+ pass
42
+
43
+
44
+ class RBLNRobertaForMaskedLMConfig(RBLNModelForMaskedLMConfig):
45
+ pass
46
+
47
+
48
+ class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
49
+ pass
@@ -0,0 +1,142 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ from ..configuration_utils import RBLNModelConfig
18
+
19
+
20
+ class _RBLNTransformerEncoderConfig(RBLNModelConfig):
21
+ rbln_model_input_names: Optional[List[str]] = None
22
+
23
+ def __init__(
24
+ self,
25
+ max_seq_len: Optional[int] = None,
26
+ batch_size: Optional[int] = None,
27
+ model_input_names: Optional[List[str]] = None,
28
+ **kwargs,
29
+ ):
30
+ """
31
+ Args:
32
+ max_seq_len (Optional[int]): Maximum sequence length supported by the model.
33
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
34
+ model_input_names (Optional[List[str]]): Names of the input tensors for the model.
35
+ Defaults to class-specific rbln_model_input_names if not provided.
36
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
37
+
38
+ Raises:
39
+ ValueError: If batch_size is not a positive integer.
40
+ """
41
+ super().__init__(**kwargs)
42
+ self.max_seq_len = max_seq_len
43
+ self.batch_size = batch_size or 1
44
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
45
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
46
+
47
+ self.model_input_names = model_input_names or self.rbln_model_input_names
48
+
49
+
50
+ class _RBLNImageModelConfig(RBLNModelConfig):
51
+ def __init__(
52
+ self, image_size: Optional[Union[int, Tuple[int, int]]] = None, batch_size: Optional[int] = None, **kwargs
53
+ ):
54
+ """
55
+ Args:
56
+ image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
57
+ Can be an integer for square images or a tuple (height, width).
58
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
59
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
60
+
61
+ Raises:
62
+ ValueError: If batch_size is not a positive integer.
63
+ """
64
+ super().__init__(**kwargs)
65
+ self.image_size = image_size
66
+ self.batch_size = batch_size or 1
67
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
68
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
69
+
70
+ @property
71
+ def image_width(self):
72
+ if isinstance(self.image_size, int):
73
+ return self.image_size
74
+ elif isinstance(self.image_size, (list, tuple)):
75
+ return self.image_size[1]
76
+ else:
77
+ return self.image_size["width"]
78
+
79
+ @property
80
+ def image_height(self):
81
+ if isinstance(self.image_size, int):
82
+ return self.image_size
83
+ elif isinstance(self.image_size, (list, tuple)):
84
+ return self.image_size[0]
85
+ else:
86
+ return self.image_size["height"]
87
+
88
+
89
+ class RBLNModelForQuestionAnsweringConfig(_RBLNTransformerEncoderConfig):
90
+ pass
91
+
92
+
93
+ class RBLNModelForSequenceClassificationConfig(_RBLNTransformerEncoderConfig):
94
+ pass
95
+
96
+
97
+ class RBLNModelForMaskedLMConfig(_RBLNTransformerEncoderConfig):
98
+ pass
99
+
100
+
101
+ class RBLNModelForTextEncodingConfig(_RBLNTransformerEncoderConfig):
102
+ pass
103
+
104
+
105
+ # FIXME : Appropriate name ?
106
+ class RBLNTransformerEncoderForFeatureExtractionConfig(_RBLNTransformerEncoderConfig):
107
+ pass
108
+
109
+
110
+ class RBLNModelForImageClassificationConfig(_RBLNImageModelConfig):
111
+ pass
112
+
113
+
114
+ class RBLNModelForDepthEstimationConfig(_RBLNImageModelConfig):
115
+ pass
116
+
117
+
118
+ class RBLNModelForAudioClassificationConfig(RBLNModelConfig):
119
+ def __init__(
120
+ self,
121
+ batch_size: Optional[int] = None,
122
+ max_length: Optional[int] = None,
123
+ num_mel_bins: Optional[int] = None,
124
+ **kwargs,
125
+ ):
126
+ """
127
+ Args:
128
+ batch_size (Optional[int]): The batch size for inference. Defaults to 1.
129
+ max_length (Optional[int]): Maximum length of the audio input in time dimension.
130
+ num_mel_bins (Optional[int]): Number of Mel frequency bins for audio processing.
131
+ **kwargs: Additional arguments passed to the parent RBLNModelConfig.
132
+
133
+ Raises:
134
+ ValueError: If batch_size is not a positive integer.
135
+ """
136
+ super().__init__(**kwargs)
137
+ self.batch_size = batch_size or 1
138
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
139
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
140
+
141
+ self.max_length = max_length
142
+ self.num_mel_bins = num_mel_bins