optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package. It is provided for informational purposes only and reflects the package contents as published to their respective public registries.
Files changed (80)
  1. optimum/rbln/__init__.py +26 -33
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/diffusers/__init__.py +4 -0
  4. optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
  5. optimum/rbln/diffusers/models/__init__.py +2 -0
  6. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
  8. optimum/rbln/diffusers/models/controlnet.py +1 -1
  9. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  10. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
  11. optimum/rbln/diffusers/pipelines/__init__.py +1 -0
  12. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
  13. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
  17. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
  20. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
  21. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
  27. optimum/rbln/modeling.py +13 -347
  28. optimum/rbln/modeling_base.py +24 -4
  29. optimum/rbln/modeling_config.py +31 -7
  30. optimum/rbln/ops/__init__.py +26 -0
  31. optimum/rbln/ops/attn.py +221 -0
  32. optimum/rbln/ops/flash_attn.py +70 -0
  33. optimum/rbln/ops/kv_cache_update.py +69 -0
  34. optimum/rbln/transformers/__init__.py +20 -0
  35. optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
  36. optimum/rbln/transformers/modeling_generic.py +385 -0
  37. optimum/rbln/transformers/models/auto/__init__.py +23 -0
  38. optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
  39. optimum/rbln/transformers/models/bart/__init__.py +0 -1
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
  42. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  43. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
  44. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
  45. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
  46. optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
  47. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  48. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
  49. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
  51. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
  52. optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
  53. optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
  54. optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
  55. optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
  56. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
  57. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
  58. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
  59. optimum/rbln/transformers/models/t5/__init__.py +0 -1
  60. optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
  61. optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
  62. optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
  63. optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
  64. optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
  65. optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
  66. optimum/rbln/utils/decorator_utils.py +51 -15
  67. optimum/rbln/utils/import_utils.py +8 -1
  68. optimum/rbln/utils/logging.py +38 -1
  69. optimum/rbln/utils/model_utils.py +0 -1
  70. optimum/rbln/utils/runtime_utils.py +9 -3
  71. optimum/rbln/utils/save_utils.py +17 -0
  72. optimum/rbln/utils/submodule.py +23 -0
  73. optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
  74. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
  75. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
  76. optimum/rbln/transformers/cache_utils.py +0 -107
  77. optimum/rbln/utils/timer_utils.py +0 -43
  78. optimum_rbln-0.1.15.dist-info/METADATA +0 -106
  79. optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
  80. {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
@@ -91,21 +91,36 @@ class RBLNCompileConfig:
         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
         return self
 
-    def get_dummy_inputs(self, fill=0):
+    def get_dummy_inputs(
+        self, fill=0, static_tensors: Dict[str, torch.Tensor] = {}, meta_tensor_names: List[str] = []
+    ):
         dummy = []
         for name, shape, dtype in self.input_info:
-            dummy.append(
-                torch.fill(torch.zeros(*shape, dtype=getattr(torch, dtype)), fill)
-                if len(shape) > 0
-                else torch.tensor(fill, dtype=getattr(torch, dtype))
-            )
+            if name in static_tensors:
+                tensor = static_tensors[name]
+                if shape != list(tensor.shape):
+                    raise RuntimeError(f"Different shape for dummy inputs. ({shape} != {list(tensor.shape)})")
+                if getattr(torch, dtype) != tensor.dtype:
+                    raise RuntimeError(f"Different dtype for dummy inputs ({dtype} != {tensor.dtype})")
+                dummy.append(tensor)
+            else:
+                if name in meta_tensor_names:
+                    device = "meta"
+                else:
+                    device = "cpu"
+
+                dummy.append(
+                    torch.fill(torch.empty(*shape, dtype=getattr(torch, dtype), device=torch.device(device)), fill)
+                    if len(shape) > 0
+                    else torch.tensor(fill, dtype=getattr(torch, dtype), device=torch.device(device))
+                )
         return tuple(dummy)
 
     def asdict(self):
         return asdict(self)
 
 
-RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map"]
+RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map", "activate_profiler"]
 COMPILE_KEYWORDS = ["compiled_model_name", "mod_name", "input_info", "fusion", "npu", "tensor_parallel_size"]
 
 
@@ -243,6 +258,15 @@ class RBLNConfig:
             return rbln_device_map
         return self.runtime_cfg["device_map"]
 
+    @property
+    def activate_profiler(self):
+        context = ContextRblnConfig.get_current_context()["activate_profiler"]
+        if context:
+            return context
+        elif self.runtime_cfg.get("activate_profiler", None) is None:
+            return False
+        return self.runtime_cfg["activate_profiler"]
+
 
 def use_rbln_config(fn):
     """
@@ -0,0 +1,26 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .attn import register_rbln_custom_attention, register_rbln_custom_attention_add_softmax
+from .flash_attn import register_rbln_custom_flash_attention
+from .kv_cache_update import register_rbln_custom_cache_update
@@ -0,0 +1,221 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from functools import lru_cache
+
+import torch
+from packaging import version
+
+
+if version.parse(torch.__version__) > version.parse("2.4.0"):
+    register_fake = torch.library.register_fake
+else:
+    register_fake = torch.library.impl_abstract
+
+
+@lru_cache
+def register_rbln_custom_attention():
+    torch.library.define(
+        "rbln_custom_ops::attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::attn_decode", "cpu")
+    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+        """Defines the computation pattern for fused attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Pattern components that compiler fuses into a single op:
+        1. KV cache updates with new key/value states
+        2. Scaled dot-product attention computation
+        3. Masked softmax operation
+        4. Final attention output computation
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+        - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+        - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - seq: [1] - Current sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+                - attn_output: [batch=1, n_heads, n_groups, 1, head_dim] - Attention output
+                - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
+                - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+        """
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::attn_decode")
+    def attn_decode_abstract(q, k, v, m, kcache, vcache, seq, partition):
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::attn_prefill", "cpu")
+    def attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
+        """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Key differences from decode pattern:
+        - Handles prefill phase with multiple input tokens
+        - Takes explicit batch index for continuous batching
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+        - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+        - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - batch: [1] - Batch index for cache access
+        - seq: [1] - Starting sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+                - attn_output: [batch=1, n_heads, n_groups, seq_len, head_dim] - Attention output
+                - empty_kcache: Same shape as input kcache - Placeholder for compiler
+                - empty_vcache: Same shape as input vcache - Placeholder for compiler
+        """
+        return q, kcache, vcache
+
+    @register_fake("rbln_custom_ops::attn_prefill")
+    def attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
+        return q, kcache, vcache
+
+
+@lru_cache
+def register_rbln_custom_attention_add_softmax():
+    torch.library.define(
+        "rbln_custom_ops::attn_decode_add_softmax",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::attn_decode_add_softmax", "cpu")
+    def attn_decode_add_softmax_cpu(q, k, v, mask, kcache, vcache, seq, scale):
+        """Defines the computation pattern for fused attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Pattern components that compiler fuses into a single op:
+        1. KV cache updates with new key/value states
+        2. Scaled dot-product attention computation
+        3. add-softmax operation
+        4. Final attention output computation
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
+        - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
+        - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - seq: [1] - Current sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+                - attn_output: [batch=1, n_heads, 1, 1, head_dim] - Attention output
+                - kcache: Same shape as input kcache, batch=1 - Placeholder for compiler
+                - vcache: Same shape as input vcache, batch=1 - Placeholder for compiler
+        """
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::attn_decode_add_softmax")
+    def attn_decode_add_softmax_abstract(q, k, v, m, kcache, vcache, seq, partition):
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::attn_prefill_add_softmax",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::attn_prefill_add_softmax", "cpu")
+    def attn_prefill_add_softmax_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
+        """Defines the computation pattern for prefill phase attention with KV cache updates.
+
+        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
+        a single optimized NPU operation. It is NOT meant for CPU execution.
+
+        Key differences from decode pattern:
+        - Handles prefill phase with multiple input tokens
+        - Takes explicit batch index for continuous batching
+
+        Expected tensor shapes:
+        - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
+        - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
+        - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
+        - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
+        - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
+        - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
+        - batch: [1] - Batch index for cache access
+        - seq: [1] - Starting sequence position
+        - scale: [] - Attention scale factor
+
+        Returns:
+            Tuple[Tensor, Tensor, Tensor]:
+                - attn_output: [batch=1, n_heads, seq_len, 1, head_dim] - Attention output
+                - empty_kcache: Same shape as input kcache - Placeholder for compiler
+                - empty_vcache: Same shape as input vcache - Placeholder for compiler
+        """
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::attn_prefill_add_softmax")
+    def attn_prefill_add_softmax_abstract(q, k, v, m, kcache, vcache, batch, seq, partition):
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
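
The docstrings above describe pattern ops rather than kernels: on CPU they only return placeholders, and the RBLN compiler is expected to match the pattern and emit a single fused NPU operation. A sketch of registering and invoking the decode op with the documented shapes (the sizes here are illustrative, not taken from the package):

    import torch
    from optimum.rbln.ops import register_rbln_custom_attention

    register_rbln_custom_attention()

    n_heads, n_groups, head_dim, max_seq_len, batch_size = 4, 2, 8, 16, 2
    q = torch.randn(1, n_heads, n_groups, 1, head_dim)        # query for a single decode token
    k = torch.randn(1, n_heads, 1, 1, head_dim)
    v = torch.randn(1, n_heads, 1, 1, head_dim)
    mask = torch.zeros(1, n_heads, 1, 1, max_seq_len)
    kcache = torch.zeros(batch_size, n_heads, 1, max_seq_len, head_dim)
    vcache = torch.zeros(batch_size, n_heads, 1, max_seq_len, head_dim)
    seq = torch.tensor([3])                                    # current sequence position
    scale = torch.tensor(head_dim ** -0.5)

    out, k_placeholder, v_placeholder = torch.ops.rbln_custom_ops.attn_decode(
        q, k, v, mask, kcache, vcache, seq, scale
    )
    print(out.shape)  # torch.Size([1, 4, 2, 1, 8]); on CPU this is just q passed through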
@@ -0,0 +1,70 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from functools import lru_cache
+
+import torch
+from packaging import version
+
+
+if version.parse(torch.__version__) > version.parse("2.4.0"):
+    register_fake = torch.library.register_fake
+else:
+    register_fake = torch.library.impl_abstract
+
+
+@lru_cache
+def register_rbln_custom_flash_attention():
+    torch.library.define(
+        "rbln_custom_ops::flash_attn_decode",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
+    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    @register_fake("rbln_custom_ops::flash_attn_decode")
+    def flash_attn_decode_abstract(q, k, v, m, kcache, vcache, seq, scale, partition):
+        return (
+            q,
+            torch.empty(1, *kcache.shape[1:], device=kcache.device),
+            torch.empty(1, *vcache.shape[1:], device=vcache.device),
+        )
+
+    torch.library.define(
+        "rbln_custom_ops::flash_attn_prefill",
+        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
+    )
+
+    @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
+    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
+        return q, kcache, vcache
+
+    @register_fake("rbln_custom_ops::flash_attn_prefill")
+    def flash_attn_prefill_abstract(q, k, v, m, kcache, vcache, batch, seq, scale, partition):
+        return q, kcache, vcache
@@ -0,0 +1,69 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from functools import lru_cache
+
+import torch
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
+
+
+if is_torch_greater_or_equal_than_2_4:
+    register_fake = torch.library.register_fake
+else:
+    register_fake = torch.library.impl_abstract
+
+
+@lru_cache
+def register_rbln_custom_cache_update():
+    # Define the RBLN custom operation "rbln_cache_update" which updates a cache tensor with a given state tensor.
+    # This operation is designed to perform in-place updates directly on the device without needing to transfer the cache back to the host.
+    # The `position` parameter specifies the start index for the update along the specified axis, allowing flexible updates to any part of the cache tensor.
+    torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")
+
+    # Implementation of the "rbln_cache_update" operation for the CPU.
+    @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
+    def rbln_cache_update_cpu(cache, state, position, axis):
+        assert position.dim() == 0
+        assert axis.dim() == 0
+
+        # Calculate the start (s) and end (e) indices for the update based on the position and the shape of the state tensor along the specified axis.
+        s = position  # Start index for the update, specified by the position.
+        e = (
+            position + state.shape[axis]
+        )  # End index is determined by adding the size of the state along the given axis.
+
+        # Update the specified portion of the cache tensor with the state tensor, using `slice_scatter`.
+        # This operation modifies the cache tensor in-place directly on the device, avoiding any unnecessary transfers between host and device.
+        updated_cache = cache.slice_scatter(state, dim=axis, start=s, end=e)
+
+        # Return the updated cache tensor.
+        return updated_cache
+
+    # Register a "fake" implementation of the "rbln_cache_update" operation.
+    # This serves as an abstract definition for the RBLN compiler to recognize the operation and generate an optimized implementation.
+    @register_fake("rbln_custom_ops::rbln_cache_update")
+    def rbln_cache_update_abstract(cache, state, position, axis):
+        # Return a tensor with the same shape as the input cache tensor.
+        # This is a placeholder for the abstract implementation and does not perform any actual computation.
+        # Like the actual implementation, the abstraction assumes in-place device-side updates.
+        return torch.empty_like(cache)
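
The comments above describe an update that the compiler lowers to an in-place, on-device write; the CPU reference path is an ordinary out-of-place `slice_scatter`. A standalone sketch of that semantics with illustrative shapes (not tied to the package's actual cache layout):

    import torch

    cache = torch.zeros(1, 2, 16, 8)   # e.g. [batch, n_heads, max_seq_len, head_dim]
    state = torch.ones(1, 2, 4, 8)     # new key/value states for 4 tokens
    position, axis = 5, 2              # write starts at sequence index 5

    # Same indexing the op uses: copy `state` into cache[:, :, 5:9, :].
    updated = cache.slice_scatter(state, dim=axis, start=position, end=position + state.shape[axis])
    print(updated[0, 0, 4:10, 0])      # tensor([0., 1., 1., 1., 1., 0.])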
@@ -63,10 +63,30 @@ _import_structure = {
         "RBLNXLMRobertaModel",
         "RBLNMistralForCausalLM",
     ],
+    "modeling_alias": [
+        "RBLNASTForAudioClassification",
+        "RBLNBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnswering",
+        "RBLNResNetForImageClassification",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForMaskedLM",
+        "RBLNViTForImageClassification",
+    ],
 }
 
 if TYPE_CHECKING:
     from .cache_utils import RebelDynamicCache
+    from .modeling_alias import (
+        RBLNASTForAudioClassification,
+        RBLNBertForQuestionAnswering,
+        RBLNDistilBertForQuestionAnswering,
+        RBLNResNetForImageClassification,
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForSequenceClassification,
+        RBLNViTForImageClassification,
+        RBLNXLMRobertaForSequenceClassification,
+    )
    from .models import (
        RBLNAutoModel,
        RBLNAutoModelForAudioClassification,
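
These `_import_structure` and `TYPE_CHECKING` entries follow the lazy-import pattern already used in this file, so the alias classes should resolve from the transformers subpackage without eagerly importing their dependencies. A usage sketch, assuming the module keeps its existing lazy-module wiring and the usual `from_pretrained` entry point (the path is a placeholder):

    # Hypothetical usage; the alias re-exports a generic task class under a model-specific name.
    from optimum.rbln.transformers import RBLNBertForQuestionAnswering

    model = RBLNBertForQuestionAnswering.from_pretrained("path/to/compiled-model")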
@@ -21,7 +21,8 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .modeling import (
+from ..utils.logging import get_logger
+from .modeling_generic import (
     RBLNModelForAudioClassification,
     RBLNModelForImageClassification,
     RBLNModelForMaskedLM,
@@ -30,6 +31,9 @@ from .modeling import (
 )
 
 
+logger = get_logger()
+
+
 class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
     pass
 