optimum-rbln 0.1.15__py3-none-any.whl → 0.2.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +8 -1
- optimum/rbln/utils/logging.py +38 -1
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/METADATA +0 -106
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
optimum/rbln/modeling_config.py
CHANGED
@@ -91,21 +91,36 @@ class RBLNCompileConfig:
|
|
91
91
|
self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
|
92
92
|
return self
|
93
93
|
|
94
|
-
def get_dummy_inputs(
    self,
    fill=0,
    static_tensors: "Optional[Dict[str, torch.Tensor]]" = None,
    meta_tensor_names: "Optional[List[str]]" = None,
):
    """Build one dummy tensor per entry of ``self.input_info``.

    Args:
        fill: Value used to fill every generated (non-static) tensor.
        static_tensors: Optional mapping from input name to a ready-made tensor. Such a
            tensor is returned as-is after its shape and dtype are checked against
            ``input_info``.
        meta_tensor_names: Optional names of inputs to allocate on the ``meta`` device
            instead of the CPU.

    Returns:
        A tuple with one tensor per ``(name, shape, dtype)`` entry of ``self.input_info``,
        in the same order.

    Raises:
        RuntimeError: If a static tensor's shape or dtype differs from ``input_info``.
    """
    # ``None`` defaults avoid sharing one mutable default object across calls.
    static_tensors = {} if static_tensors is None else static_tensors
    meta_tensor_names = [] if meta_tensor_names is None else meta_tensor_names

    dummy = []
    for name, shape, dtype in self.input_info:
        # ``dtype`` is stored as the name of a torch dtype attribute.
        torch_dtype = getattr(torch, dtype)

        if name in static_tensors:
            tensor = static_tensors[name]
            # Compare as lists so a tuple-valued ``shape`` is not reported as a mismatch.
            if list(shape) != list(tensor.shape):
                raise RuntimeError(f"Different shape for dummy inputs. ({shape} != {list(tensor.shape)})")
            if torch_dtype != tensor.dtype:
                raise RuntimeError(f"Different dtype for dummy inputs ({dtype} != {tensor.dtype})")
            dummy.append(tensor)
            continue

        device = torch.device("meta" if name in meta_tensor_names else "cpu")
        if len(shape) > 0:
            dummy.append(torch.fill(torch.empty(*shape, dtype=torch_dtype, device=device), fill))
        else:
            # A 0-dim (scalar) input is created directly from ``fill``.
            dummy.append(torch.tensor(fill, dtype=torch_dtype, device=device))
    return tuple(dummy)
|
103
118
|
|
104
119
|
def asdict(self):
    """Return this config converted to a plain ``dict``.

    Delegates to the module-level ``asdict`` name (presumably ``dataclasses.asdict``;
    the import is outside this view).
    """
    as_plain_dict = asdict(self)
    return as_plain_dict
|
106
121
|
|
107
122
|
|
108
|
-
# Option names treated as runtime keywords.
RUNTIME_KEYWORDS = [
    "create_runtimes",
    "optimize_host_memory",
    "device",
    "device_map",
    "activate_profiler",
]

# Option names treated as compile keywords.
COMPILE_KEYWORDS = [
    "compiled_model_name",
    "mod_name",
    "input_info",
    "fusion",
    "npu",
    "tensor_parallel_size",
]
|
110
125
|
|
111
126
|
|
@@ -243,6 +258,15 @@ class RBLNConfig:
|
|
243
258
|
return rbln_device_map
|
244
259
|
return self.runtime_cfg["device_map"]
|
245
260
|
|
261
|
+
@property
def activate_profiler(self):
    """Whether the runtime profiler should be activated.

    A truthy ``activate_profiler`` value in the current ``ContextRblnConfig`` context
    wins; otherwise the value stored in ``runtime_cfg`` is used, falling back to
    ``False`` when it is absent or ``None``.
    """
    context_value = ContextRblnConfig.get_current_context()["activate_profiler"]
    if context_value:
        return context_value

    configured = self.runtime_cfg.get("activate_profiler", None)
    if configured is None:
        return False
    return configured
|
269
|
+
|
246
270
|
|
247
271
|
def use_rbln_config(fn):
|
248
272
|
"""
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from .attn import register_rbln_custom_attention, register_rbln_custom_attention_add_softmax
|
25
|
+
from .flash_attn import register_rbln_custom_flash_attention
|
26
|
+
from .kv_cache_update import register_rbln_custom_cache_update
|
optimum/rbln/ops/attn.py
ADDED
@@ -0,0 +1,221 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from functools import lru_cache
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from packaging import version
|
28
|
+
|
29
|
+
|
30
|
+
if version.parse(torch.__version__) > version.parse("2.4.0"):
|
31
|
+
register_fake = torch.library.register_fake
|
32
|
+
else:
|
33
|
+
register_fake = torch.library.impl_abstract
|
34
|
+
|
35
|
+
|
36
|
+
@lru_cache
def register_rbln_custom_attention():
    """Register the ``rbln_custom_ops::attn_decode`` and ``rbln_custom_ops::attn_prefill`` ops.

    Each op gets a CPU kernel and a fake (abstract) kernel. Neither kernel computes
    attention: they only return outputs of the expected structure so the graph can be
    traced, and the RBLN compiler replaces the op pattern with a fused NPU kernel.

    ``lru_cache`` on this zero-argument function means the body runs only once, so
    repeated calls do not try to re-define the same op names.
    """

    def _batch1_placeholder(cache):
        # Uninitialized stand-in for a cache output: leading dim forced to 1, default dtype,
        # same device as ``cache``.
        return torch.empty(1, *cache.shape[1:], device=cache.device)

    torch.library.define(
        "rbln_custom_ops::attn_decode",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
    )

    @torch.library.impl("rbln_custom_ops::attn_decode", "cpu")
    def attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale):
        """Defines the computation pattern for fused attention with KV cache updates.

        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
        a single optimized NPU operation. It is NOT meant for CPU execution; this kernel
        does not compute attention.

        Pattern components that the compiler fuses into a single op:
            1. KV cache updates with new key/value states
            2. Scaled dot-product attention computation
            3. Masked softmax operation
            4. Final attention output computation

        Expected tensor shapes:
            - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
            - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
            - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
            - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
            - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
            - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
            - seq: [1] - Current sequence position
            - scale: [] - Attention scale factor

        Returns:
            Tuple[Tensor, Tensor, Tensor]:
                - attn_output: ``q`` itself, i.e. [batch=1, n_heads, n_groups, 1, head_dim]
                - kcache placeholder: uninitialized tensor of shape (1, *kcache.shape[1:]),
                  default dtype, on kcache's device
                - vcache placeholder: likewise, derived from vcache
        """
        return q, _batch1_placeholder(kcache), _batch1_placeholder(vcache)

    @register_fake("rbln_custom_ops::attn_decode")
    def attn_decode_abstract(q, k, v, mask, kcache, vcache, seq, scale):
        # Mirrors the CPU kernel's outputs so tracing sees the same shapes and devices.
        return q, _batch1_placeholder(kcache), _batch1_placeholder(vcache)

    torch.library.define(
        "rbln_custom_ops::attn_prefill",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
    )

    @torch.library.impl("rbln_custom_ops::attn_prefill", "cpu")
    def attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
        """Defines the computation pattern for prefill phase attention with KV cache updates.

        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
        a single optimized NPU operation. It is NOT meant for CPU execution; this kernel
        does not compute attention.

        Key differences from decode pattern:
            - Handles prefill phase with multiple input tokens
            - Takes explicit batch index for continuous batching

        Expected tensor shapes:
            - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
            - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
            - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
            - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
            - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
            - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
            - batch: [1] - Batch index for cache access
            - seq: [1] - Starting sequence position
            - scale: [] - Attention scale factor

        Returns:
            Tuple[Tensor, Tensor, Tensor]:
                - attn_output: ``q`` itself, i.e. [batch=1, n_heads, n_groups, seq_len, head_dim]
                - kcache: the input ``kcache`` tensor itself (unchanged, not a copy)
                - vcache: the input ``vcache`` tensor itself (unchanged, not a copy)
        """
        return q, kcache, vcache

    @register_fake("rbln_custom_ops::attn_prefill")
    def attn_prefill_abstract(q, k, v, mask, kcache, vcache, batch, seq, scale):
        # Mirrors the CPU kernel: the inputs are passed straight through.
        return q, kcache, vcache
|
124
|
+
|
125
|
+
|
126
|
+
@lru_cache
def register_rbln_custom_attention_add_softmax():
    """Register ``rbln_custom_ops::attn_decode_add_softmax`` and ``attn_prefill_add_softmax``.

    Each op gets a CPU kernel and a fake (abstract) kernel. Neither kernel computes
    attention: they only return outputs of the expected structure so the graph can be
    traced, and the RBLN compiler replaces the op pattern with a fused NPU kernel.

    ``lru_cache`` on this zero-argument function means the body runs only once, so
    repeated calls do not try to re-define the same op names.
    """

    def _batch1_placeholder(cache):
        # Uninitialized stand-in for a cache output: leading dim forced to 1, default dtype,
        # same device as ``cache``.
        return torch.empty(1, *cache.shape[1:], device=cache.device)

    torch.library.define(
        "rbln_custom_ops::attn_decode_add_softmax",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d) -> Tensor[]",
    )

    @torch.library.impl("rbln_custom_ops::attn_decode_add_softmax", "cpu")
    def attn_decode_add_softmax_cpu(q, k, v, mask, kcache, vcache, seq, scale):
        """Defines the computation pattern for fused attention with KV cache updates.

        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
        a single optimized NPU operation. It is NOT meant for CPU execution; this kernel
        does not compute attention.

        Pattern components that the compiler fuses into a single op:
            1. KV cache updates with new key/value states
            2. Scaled dot-product attention computation
            3. add-softmax operation
            4. Final attention output computation

        Expected tensor shapes:
            - q: [batch=1, n_heads, n_groups, 1, head_dim] - Query states for single token
            - k: [batch=1, n_heads, 1, 1, head_dim] - Key states for current input
            - v: [batch=1, n_heads, 1, 1, head_dim] - Value states for current input
            - mask: [batch=1, n_heads, 1, 1, max_seq_len] - Attention mask
            - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
            - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
            - seq: [1] - Current sequence position
            - scale: [] - Attention scale factor

        Returns:
            Tuple[Tensor, Tensor, Tensor]:
                - attn_output: ``q`` itself, i.e. [batch=1, n_heads, n_groups, 1, head_dim]
                - kcache placeholder: uninitialized tensor of shape (1, *kcache.shape[1:]),
                  default dtype, on kcache's device
                - vcache placeholder: likewise, derived from vcache
        """
        return q, _batch1_placeholder(kcache), _batch1_placeholder(vcache)

    @register_fake("rbln_custom_ops::attn_decode_add_softmax")
    def attn_decode_add_softmax_abstract(q, k, v, mask, kcache, vcache, seq, scale):
        # Mirrors the CPU kernel's outputs so tracing sees the same shapes and devices.
        return q, _batch1_placeholder(kcache), _batch1_placeholder(vcache)

    torch.library.define(
        "rbln_custom_ops::attn_prefill_add_softmax",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e) -> Tensor[]",
    )

    @torch.library.impl("rbln_custom_ops::attn_prefill_add_softmax", "cpu")
    def attn_prefill_add_softmax_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale):
        """Defines the computation pattern for prefill phase attention with KV cache updates.

        IMPORTANT: This op serves as a pattern definition for the RBLN compiler to generate
        a single optimized NPU operation. It is NOT meant for CPU execution; this kernel
        does not compute attention.

        Key differences from decode pattern:
            - Handles prefill phase with multiple input tokens
            - Takes explicit batch index for continuous batching

        Expected tensor shapes:
            - q: [batch=1, n_heads, n_groups, seq_len, head_dim] - Query states for multiple tokens
            - k: [batch=1, n_heads, 1, seq_len, head_dim] - Key states for current input
            - v: [batch=1, n_heads, 1, seq_len, head_dim] - Value states for current input
            - mask: [batch=1, 1, 1, seq_len, max_seq_len] - Attention mask
            - kcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Key cache
            - vcache: [batch_size, n_heads, 1, max_seq_len, head_dim] - Value cache
            - batch: [1] - Batch index for cache access
            - seq: [1] - Starting sequence position
            - scale: [] - Attention scale factor

        Returns:
            Tuple[Tensor, Tensor, Tensor]:
                - attn_output: ``q`` itself, i.e. [batch=1, n_heads, n_groups, seq_len, head_dim]
                - kcache placeholder: uninitialized tensor of shape (1, *kcache.shape[1:]),
                  default dtype, on kcache's device
                - vcache placeholder: likewise, derived from vcache
        """
        return q, _batch1_placeholder(kcache), _batch1_placeholder(vcache)

    @register_fake("rbln_custom_ops::attn_prefill_add_softmax")
    def attn_prefill_add_softmax_abstract(q, k, v, mask, kcache, vcache, batch, seq, scale):
        # Mirrors the CPU kernel's outputs so tracing sees the same shapes and devices.
        return q, _batch1_placeholder(kcache), _batch1_placeholder(vcache)
|
@@ -0,0 +1,70 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from functools import lru_cache
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from packaging import version
|
28
|
+
|
29
|
+
|
30
|
+
if version.parse(torch.__version__) > version.parse("2.4.0"):
|
31
|
+
register_fake = torch.library.register_fake
|
32
|
+
else:
|
33
|
+
register_fake = torch.library.impl_abstract
|
34
|
+
|
35
|
+
|
36
|
+
@lru_cache
def register_rbln_custom_flash_attention():
    """Register ``rbln_custom_ops::flash_attn_decode`` and ``rbln_custom_ops::flash_attn_prefill``.

    Both ops take a trailing ``int`` argument (named ``partition`` in the kernels below).
    Each gets a CPU kernel and a fake kernel that only return correctly structured
    outputs for tracing; neither computes attention. ``lru_cache`` on this zero-argument
    function means the registrations happen only once.
    """

    def _batch1_placeholders(kcache, vcache):
        # Uninitialized tensors shaped like the caches but with a leading dim of 1
        # (default dtype, on each cache's device).
        key_placeholder = torch.empty(1, *kcache.shape[1:], device=kcache.device)
        value_placeholder = torch.empty(1, *vcache.shape[1:], device=vcache.device)
        return key_placeholder, value_placeholder

    # --- decode -------------------------------------------------------------------
    torch.library.define(
        "rbln_custom_ops::flash_attn_decode",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, int e) -> Tensor[]",
    )

    @torch.library.impl("rbln_custom_ops::flash_attn_decode", "cpu")
    def flash_attn_decode_cpu(q, k, v, mask, kcache, vcache, seq, scale, partition):
        # Pattern-only kernel: returns q plus batch-1 cache placeholders.
        return (q, *_batch1_placeholders(kcache, vcache))

    @register_fake("rbln_custom_ops::flash_attn_decode")
    def flash_attn_decode_abstract(q, k, v, mask, kcache, vcache, seq, scale, partition):
        return (q, *_batch1_placeholders(kcache, vcache))

    # --- prefill ------------------------------------------------------------------
    torch.library.define(
        "rbln_custom_ops::flash_attn_prefill",
        "(Tensor x, Tensor y, Tensor z, Tensor w, Tensor a, Tensor b, Tensor c, Tensor d, Tensor e, int f) -> Tensor[]",
    )

    @torch.library.impl("rbln_custom_ops::flash_attn_prefill", "cpu")
    def flash_attn_prefill_cpu(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
        # Pattern-only kernel: passes q and the input caches straight through.
        return q, kcache, vcache

    @register_fake("rbln_custom_ops::flash_attn_prefill")
    def flash_attn_prefill_abstract(q, k, v, mask, kcache, vcache, batch, seq, scale, partition):
        return q, kcache, vcache
|
@@ -0,0 +1,69 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
from functools import lru_cache
|
25
|
+
|
26
|
+
import torch
|
27
|
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
28
|
+
|
29
|
+
|
30
|
+
# Pick the decorator used to register fake (abstract) kernels: ``register_fake`` when
# transformers reports torch >= 2.4, otherwise ``impl_abstract``.
register_fake = (
    torch.library.register_fake if is_torch_greater_or_equal_than_2_4 else torch.library.impl_abstract
)
|
34
|
+
|
35
|
+
|
36
|
+
@lru_cache
def register_rbln_custom_cache_update():
    """Register the ``rbln_custom_ops::rbln_cache_update`` custom op (once, via ``lru_cache``).

    The op writes ``state`` into ``cache`` along dimension ``axis``, starting at index
    ``position``; both ``position`` and ``axis`` are 0-dim tensors. The CPU kernel below is
    a reference implementation and is out-of-place: ``slice_scatter`` returns a new
    tensor and leaves ``cache`` untouched. The design intent stated by the authors is
    that on the RBLN device the update happens in place without copying the cache back
    to the host; that part is not visible in this file.
    """
    torch.library.define("rbln_custom_ops::rbln_cache_update", "(Tensor x, Tensor y, Tensor z, Tensor w) -> Tensor")

    @torch.library.impl("rbln_custom_ops::rbln_cache_update", "cpu")
    def rbln_cache_update_cpu(cache, state, position, axis):
        """Return a copy of ``cache`` with ``state`` written at ``[position, position + len)``.

        Raises:
            ValueError: If ``position`` or ``axis`` is not a 0-dim tensor.
        """
        # Explicit checks instead of ``assert`` so validation still runs under ``python -O``.
        if position.dim() != 0:
            raise ValueError(f"`position` must be a 0-dim tensor, got shape {tuple(position.shape)}")
        if axis.dim() != 0:
            raise ValueError(f"`axis` must be a 0-dim tensor, got shape {tuple(axis.shape)}")

        # The written span covers as many elements along ``axis`` as ``state`` has there.
        start = position
        end = position + state.shape[axis]
        return cache.slice_scatter(state, dim=axis, start=start, end=end)

    @register_fake("rbln_custom_ops::rbln_cache_update")
    def rbln_cache_update_abstract(cache, state, position, axis):
        # Fake kernel: only the output's shape/dtype/device matter for tracing, and they
        # match ``cache``.
        return torch.empty_like(cache)
|
@@ -63,10 +63,30 @@ _import_structure = {
|
|
63
63
|
"RBLNXLMRobertaModel",
|
64
64
|
"RBLNMistralForCausalLM",
|
65
65
|
],
|
66
|
+
"modeling_alias": [
|
67
|
+
"RBLNASTForAudioClassification",
|
68
|
+
"RBLNBertForQuestionAnswering",
|
69
|
+
"RBLNDistilBertForQuestionAnswering",
|
70
|
+
"RBLNResNetForImageClassification",
|
71
|
+
"RBLNXLMRobertaForSequenceClassification",
|
72
|
+
"RBLNRobertaForSequenceClassification",
|
73
|
+
"RBLNRobertaForMaskedLM",
|
74
|
+
"RBLNViTForImageClassification",
|
75
|
+
],
|
66
76
|
}
|
67
77
|
|
68
78
|
if TYPE_CHECKING:
|
69
79
|
from .cache_utils import RebelDynamicCache
|
80
|
+
from .modeling_alias import (
|
81
|
+
RBLNASTForAudioClassification,
|
82
|
+
RBLNBertForQuestionAnswering,
|
83
|
+
RBLNDistilBertForQuestionAnswering,
|
84
|
+
RBLNResNetForImageClassification,
|
85
|
+
RBLNRobertaForMaskedLM,
|
86
|
+
RBLNRobertaForSequenceClassification,
|
87
|
+
RBLNViTForImageClassification,
|
88
|
+
RBLNXLMRobertaForSequenceClassification,
|
89
|
+
)
|
70
90
|
from .models import (
|
71
91
|
RBLNAutoModel,
|
72
92
|
RBLNAutoModelForAudioClassification,
|
@@ -21,7 +21,8 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
from .
|
24
|
+
from ..utils.logging import get_logger
|
25
|
+
from .modeling_generic import (
|
25
26
|
RBLNModelForAudioClassification,
|
26
27
|
RBLNModelForImageClassification,
|
27
28
|
RBLNModelForMaskedLM,
|
@@ -30,6 +31,9 @@ from .modeling import (
|
|
30
31
|
)
|
31
32
|
|
32
33
|
|
34
|
+
# Module-level logger from optimum.rbln's ``utils.logging.get_logger`` helper.
logger = get_logger()
|
35
|
+
|
36
|
+
|
33
37
|
class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
    """Named alias of ``RBLNModelForAudioClassification``; it adds no behavior of its own."""
|
35
39
|
|