optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py:

@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import asdict, dataclass
 from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
-from ....configuration_utils import RBLNModelConfig
+from ....configuration_utils import RBLNModelConfig, RBLNSerializableConfigProtocol
 from ....utils.logging import get_logger
 from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
@@ -59,7 +60,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         phases: Optional[List[PhaseType]] = None,
         logits_to_keep: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
-
+        kvcache_metas: Optional[List["KVCacheMeta"]] = None,
+        **kwargs: Any,
     ):
         """
         Args:
@@ -93,8 +95,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
                 processing input sequences. Defaults to 128. Must be a positive integer
                 divisible by 64. Affects prefill performance and memory usage.
             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
-                PagedAttention KV cache
-                section below for details.
+                PagedAttention KV cache at compile time. Defaults to 0 (automatically determined).
+                See the "KV Cache Number of Blocks (`kvcache_num_blocks`)" section below for details.
             decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
                 This allows the model to handle varying batch sizes efficiently during generation. If not specified,
                 defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
@@ -114,6 +116,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
                 Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
             output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
+            kvcache_metas (Optional[List["KVCacheMeta"]]): The metadata for the KV cache tensors. Handled internally if not provided. Defaults to None.
             kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
@@ -152,17 +155,15 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
 
 
         KV Cache Number of Blocks:
-            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache
-            Each block holds `kvcache_block_size` tokens of Key and Value states.
-
-            - **Automatic
-
-
-
-
-
-            - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
-              but requires careful consideration of memory limits. Setting it too high may lead to
+            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache
+            at compile time. Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+            - **Automatic Determination (Default)**: If `kvcache_num_blocks` is `0` (default), the number of blocks
+              is automatically determined during compilation to fit within the available DRAM on the NPU. This allows
+              the model to utilize the remaining memory after compilation without manual tuning, providing optimal
+              cache capacity for better performance with long sequences or larger batches.
+            - **Manual Setting**: You can explicitly set the number of blocks to a positive integer. This provides
+              finer control but requires careful consideration of memory limits. Setting it too high may lead to
               compilation errors if it exceeds available memory. The system will issue warnings if your
               setting exceeds the estimated maximum.
            - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
@@ -175,7 +176,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
 
         The optimal value depends on the specific model, task, hardware, and desired trade-off
-        between performance and memory usage.
+        between performance and memory usage. Automatic determination (default) provides a robust starting point
+        that adapts to the available DRAM on the NPU at compile time.
         """
 
         super().__init__(**kwargs)
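
For readers of this diff, a minimal usage sketch of the new default behavior. The model id and the explicit block count below are illustrative assumptions, not values taken from this package:

```python
# Minimal sketch: with the 0.10 default (kvcache_num_blocks=0) the PagedAttention
# KV cache is sized automatically at compile time to fit the NPU's remaining DRAM;
# passing a positive integer pins the block count manually.
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

auto_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192)  # blocks determined at compile time
manual_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, kvcache_num_blocks=16)  # explicit block count

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    rbln_config=auto_cfg,
)
```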
@@ -222,7 +224,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
 
-        self.kvcache_num_blocks = kvcache_num_blocks
+        self.kvcache_num_blocks = kvcache_num_blocks if kvcache_num_blocks is not None else 0
         self.cache_impl = cache_impl or "static"
         self.sliding_window = sliding_window
         self.sliding_window_layers = sliding_window_layers or []
@@ -257,6 +259,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             # Larger batch size should be at the beginning of the list.
             self.decoder_batch_sizes.sort(reverse=True)
 
+        self.kvcache_metas: List["KVCacheMeta"] = kvcache_metas or []
+
     @staticmethod
     def validate_phases_type(phases: List[PhaseType]):
         if not isinstance(phases, list):
@@ -284,12 +288,52 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
     def can_generate(self) -> bool:
         return "decode" in self.phases
 
+    @property
+    def use_image_prefill(self):
+        return "image_prefill" in self.phases
+
+    @property
+    def image_prefill_runtime_idx(self):
+        return self.phases.index("image_prefill")
+
+    @property
+    def expected_compiled_model_names(self):
+        # ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_2", ...]
+        if self.can_generate:
+            return self.phases[: self.decoder_runtime_idx] + [
+                f"decoder_batch_{batch_size}" for batch_size in self.decoder_batch_sizes
+            ]
+        else:
+            return self.phases
+
+    @property
+    def decoder_runtime_idx(self):
+        if self.can_generate:
+            return self.phases.index("decode")
+        else:
+            raise ValueError("`decode` phase is not in the phases.")
+
     @property
     def nbits_per_param(self) -> int:
         if self.quantization:
             return self.quantization.nbits_per_param
         return 16
 
+    @property
+    def is_auto_num_blocks(self) -> bool:
+        """Returns True if kvcache_num_blocks will be automatically determined during compilation to fit within the available DRAM on the NPU."""
+        return self.kvcache_num_blocks == 0
+
+    @property
+    def num_full_blocks(self) -> int:
+        return (self.max_seq_len // self.kvcache_block_size) * self.batch_size
+
+    @property
+    def num_min_blocks(self) -> int:
+        if self.attn_impl == "flash_attn":
+            return min(self.max_seq_len // self.kvcache_block_size + 1, self.num_full_blocks)
+        return self.batch_size
+
 
 class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
     """
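
To make the new block-count properties concrete, here is a small standalone sketch that mirrors the formulas above; the numbers are illustrative, not package defaults:

```python
# Mirrors num_full_blocks / num_min_blocks from the properties above with toy values.
max_seq_len = 4096
kvcache_block_size = 1024
batch_size = 2

# One full sequence needs max_seq_len // kvcache_block_size blocks; the batch needs that many per sequence.
num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 4 * 2 = 8

# With flash attention, at least one block beyond a full sequence is required, capped at num_full_blocks.
num_min_blocks_flash = min(max_seq_len // kvcache_block_size + 1, num_full_blocks)  # min(5, 8) = 5
num_min_blocks_eager = batch_size  # 2

print(num_full_blocks, num_min_blocks_flash, num_min_blocks_eager)  # 8 5 2
```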
@@ -302,3 +346,86 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
 
     _default_phases = ["prefill", "decode"]
     _default_logits_to_keep = 1
+
+
+@dataclass
+class KVCacheMeta(RBLNSerializableConfigProtocol):
+    """
+    KVCacheMeta contains metadata describing the key-value (KV) cache tensor for a specific transformer layer.
+
+    This is used during compilation and runtime on RBLN devices to manage memory and configure the
+    static or dynamic characteristics of the cache implementation for decoder-only models.
+
+    Attributes:
+        name (str): Logical name of the KV cache tensor.
+        layer_index (int): Index of the transformer layer corresponding to this cache.
+        shape (list[int]): The 4D shape of the cache tensor:
+            [num_blocks, num_heads, block_size, head_dim]. The number of blocks may be dynamic or static
+            depending on model configuration.
+        layer_type (str): String describing the attention/cache algorithm (e.g., "full_attention", "sliding_attention").
+        is_auto (bool): Whether the number of blocks is automatically determined during compilation (True) or manually specified (False).
+            In both cases, the KV cache size is fixed at compile time.
+        dtype (str): Data type of the cache buffer ("float16", "float32", etc.).
+    """
+
+    name: str
+    layer_index: int
+    shape: list[int]  # (num_blocks, num_heads, block_size(seq), head_dim)
+    layer_type: str
+    is_auto: bool
+    dtype: str
+
+    def _prepare_for_serialization(self) -> dict[str, Any]:
+        return asdict(self)
+
+    @property
+    def compile_shape(self):
+        return [1, self.shape[1], self.shape[2], self.shape[3]] if self.can_resize else self.shape
+
+    @property
+    def can_resize(self):
+        return self.is_auto and self.layer_type == "full_attention"
+
+    @property
+    def num_blocks(self) -> int:
+        return self.shape[0]
+
+    @property
+    def block_size(self) -> int:
+        return self.shape[2]
+
+    @staticmethod
+    def make(
+        name: str,
+        layer_index: int,
+        num_key_value_heads: int,
+        head_dim: int,
+        dtype: str,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> "KVCacheMeta":
+        assert len(rbln_config.compile_cfgs) == 0, "KVCacheMeta cannot be created from rbln_config with compile_cfgs"
+
+        if rbln_config.sliding_window is not None and layer_index in rbln_config.sliding_window_layers:
+            layer_type = "sliding_attention"
+            block_size = rbln_config.sliding_window
+            num_blocks = rbln_config.batch_size
+            is_auto = False
+
+        else:
+            layer_type = "full_attention"
+            block_size = rbln_config.kvcache_block_size
+
+            if rbln_config.is_auto_num_blocks:
+                num_blocks = rbln_config.num_full_blocks
+                is_auto = True
+            else:
+                num_blocks = rbln_config.kvcache_num_blocks
+                is_auto = False
+
+        shape = [num_blocks, num_key_value_heads, block_size, head_dim]
+        if num_blocks <= 0:
+            raise ValueError("`num_blocks` must be greater than 0 when using KV cache.")
+
+        return KVCacheMeta(
+            name=name, layer_index=layer_index, shape=shape, layer_type=layer_type, is_auto=is_auto, dtype=dtype
+        )
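
As an illustration of the metadata this produces, a hedged sketch with invented values (the import path is inferred from the file list above; the tensor name, head counts, and dtype are made up):

```python
# Illustrative KVCacheMeta for a full-attention layer whose block count is auto-determined.
# shape = [num_blocks, num_key_value_heads, block_size, head_dim]; values are invented.
from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import KVCacheMeta

meta = KVCacheMeta(
    name="past_key_values_0_key",  # hypothetical tensor name
    layer_index=0,
    shape=[8, 8, 1024, 128],
    layer_type="full_attention",
    is_auto=True,
    dtype="float16",
)

assert meta.can_resize                          # auto-sized and full attention
assert meta.compile_shape == [1, 8, 1024, 128]  # compiled with a single placeholder block
assert meta.num_blocks == 8 and meta.block_size == 1024
```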
optimum/rbln/transformers/models/decoderonly/configuration_lora.py:

@@ -46,7 +46,7 @@ class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
         model = RBLNLlamaForCausalLM.from_pretrained(
             model_id,
             rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
-
+            dtype="auto",
         )
 
 
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py:

@@ -75,7 +75,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )
 
-        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len, use_rotary_emb)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
         self._phase = "prefill"
 
@@ -103,7 +103,7 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM
 
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int, use_rotary_emb: bool):
         new_layers = []
         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
             is_sliding = layer_idx in self.rbln_config.sliding_window_layers
@@ -118,6 +118,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_layers,
             self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+            use_rotary_emb=use_rotary_emb,
         )
 
         if self.is_causal_lm:
@@ -144,8 +145,11 @@ class DecoderOnlyWrapper(nn.Module):
         local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
         query_position = (
             args.pop(0)
-            # query_position usage:
-            if (
+            # query_position usage: prefill & (logits_to_keep == 1 or use_local_attention)
+            if (
+                "prefill" in self.phase
+                and (self.rbln_config.logits_to_keep == 1 or self.rbln_config.use_local_attention)
+            )
             else None
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
|
|
|
240
244
|
|
|
241
245
|
Attributes:
|
|
242
246
|
config: Configuration from the original causal language model
|
|
243
|
-
_original_mod: Reference to the original model for components like lm_head
|
|
244
247
|
model: RBLN-optimized decoder model instance
|
|
245
248
|
_phase: Current processing phase ("prefill" or "decode")
|
|
246
249
|
"""
|
|
@@ -248,10 +251,9 @@ class DecoderOnlyForCausalLM(nn.Module):
|
|
|
248
251
|
def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
|
|
249
252
|
super().__init__()
|
|
250
253
|
self.config = causal_lm.config
|
|
251
|
-
self._original_mod = causal_lm
|
|
252
254
|
self.model = model
|
|
253
255
|
self._phase = "prefill"
|
|
254
|
-
self.lm_head =
|
|
256
|
+
self.lm_head = causal_lm.lm_head
|
|
255
257
|
|
|
256
258
|
@property
|
|
257
259
|
def phase(self):
|
|
@@ -293,7 +295,7 @@ class DecoderOnlyForCausalLM(nn.Module):
|
|
|
293
295
|
output_hidden_states=output_hidden_states,
|
|
294
296
|
)
|
|
295
297
|
|
|
296
|
-
if "prefill" in self.phase:
|
|
298
|
+
if "prefill" in self.phase and query_position is not None:
|
|
297
299
|
hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
|
|
298
300
|
|
|
299
301
|
logits = self.lm_head(hidden_states)
|
|
@@ -317,20 +319,35 @@ class DecoderOnlyModel(nn.Module):
         use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
     Attributes:
-        _original_mod: Reference to original Huggingface model
         layers: ModuleList of RBLN-optimized transformer layers
         _phase: Current processing phase ("prefill" or "decode")
     """
 
+    _EMBEDDING_ATTRS = ["embed_tokens", "wte"]
+    _POSITION_ATTRS = ["embed_positions", "wpe"]
+    _LAYERNORM_ATTRS = ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+
     def __init__(
         self,
         model,
         layers: List["DecoderOnlyLayer"],
         rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
+        use_rotary_emb=True,
     ):
         super().__init__()
-        self.
+        self.config = model.config
+        # Keep commonly-used original submodules registered on this wrapper so their weights
+        # are preserved in state_dict even if the original model object is not kept.
+        # Different HF model families use different attribute names; we register what we can
+        # and allow subclasses to override getters when needed.
+        self.embed_tokens = _get_attr_from_candidates(model, self._EMBEDDING_ATTRS)
+        # hasattr(model, "rotary_emb") is workaround for Qwen2VL
+        if not (use_rotary_emb or hasattr(model, "rotary_emb")):
+            self.embed_positions = _get_attr_from_candidates(model, self._POSITION_ATTRS)
+        self.norm = _get_attr_from_candidates(model, self._LAYERNORM_ATTRS)
         self.layers = nn.ModuleList(layers)
         self.rbln_config = rbln_config
         self._phase = "prefill"
|
|
|
369
386
|
cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
|
|
370
387
|
return cache_pos_for_partitions
|
|
371
388
|
|
|
372
|
-
def
|
|
373
|
-
max_cache_len = self.
|
|
389
|
+
def get_swa_custom_op_args(self, position_ids, query_position):
|
|
390
|
+
max_cache_len = self.config.sliding_window
|
|
374
391
|
valid_input_len = 1 if query_position is None else query_position + 1
|
|
375
|
-
cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1] # past seen tokens
|
|
392
|
+
cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1] # past seen tokens
|
|
376
393
|
cache_offset = (
|
|
377
394
|
torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
|
|
378
395
|
) # cache offset for next steps
|
|
379
396
|
|
|
380
|
-
|
|
397
|
+
# Causal mask for sliding window attention
|
|
398
|
+
attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
|
|
399
|
+
attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]
|
|
400
|
+
|
|
401
|
+
return cache_seq_len, cache_offset, attn_mask
|
|
381
402
|
|
|
382
403
|
def get_last_layernorm(self) -> nn.LayerNorm:
|
|
383
|
-
return self.
|
|
404
|
+
return self.norm
|
|
384
405
|
|
|
385
406
|
def get_embedding(self) -> nn.Embedding:
|
|
386
|
-
return self.
|
|
407
|
+
return self.embed_tokens
|
|
387
408
|
|
|
388
409
|
def get_pos_embedding(self) -> nn.Embedding:
|
|
389
|
-
|
|
390
|
-
"The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
|
|
391
|
-
)
|
|
410
|
+
return self.embed_positions
|
|
392
411
|
|
|
393
412
|
def forward(
|
|
394
413
|
self,
|
|
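
The decode-time sliding-window mask built above can be reproduced in isolation; a small sketch with made-up sizes (window of 8, current position 5):

```python
import torch

# Reproduces the mask arithmetic from get_swa_custom_op_args with toy values.
max_cache_len = 8                   # stands in for config.sliding_window
position_ids = torch.tensor([[5]])  # decode step for a single sequence

cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1]  # [[5]]

# Slots up to and including the number of past tokens stay enabled (1.0); later slots are masked (0.0).
attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]

print(attn_mask.shape)      # torch.Size([1, 1, 1, 8])
print(attn_mask.flatten())  # tensor([1., 1., 1., 1., 1., 1., 0., 0.])
```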
@@ -464,7 +483,8 @@ class DecoderOnlyModel(nn.Module):
 
         # Get local cache positions for sliding window layers
         if len(self.sliding_window_layers) > 0:
-
+            cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+            sliding_cache_pos = (cache_seq_len, cache_offset)
 
         all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
@@ -472,9 +492,10 @@ class DecoderOnlyModel(nn.Module):
                 all_hidden_states += (hidden_states,)
 
             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos,
|
|
|
510
531
|
self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
|
|
511
532
|
|
|
512
533
|
Attributes:
|
|
513
|
-
_original_mod: Reference to original layer for accessing components
|
|
514
534
|
self_attn: Modified attention mechanism mapped to RBLN ops at compile time
|
|
515
535
|
phase: Current operation phase ("prefill" or "decode")
|
|
516
536
|
"""
|
|
517
537
|
|
|
538
|
+
_PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]
|
|
539
|
+
_POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
|
|
540
|
+
_PRE_FF_LAYERNORM_ATTRS = None
|
|
541
|
+
_POST_FF_LAYERNORM_ATTRS = None
|
|
542
|
+
_MLP_ATTR = ("mlp",)
|
|
543
|
+
|
|
518
544
|
def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
|
|
519
545
|
super().__init__()
|
|
520
|
-
|
|
546
|
+
|
|
547
|
+
self.pre_attention_layernorm = _get_attr_from_candidates(layer, self._PRE_ATTN_LAYERNORM)
|
|
548
|
+
self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
|
|
549
|
+
self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
|
|
550
|
+
self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
|
|
551
|
+
self.mlp = _get_attr_from_candidates(layer, self._MLP_ATTR)
|
|
521
552
|
self.self_attn = self_attn
|
|
522
553
|
self._phase = "prefill"
|
|
523
554
|
self.lora_config = lora_config
|
|
@@ -547,13 +578,19 @@ class DecoderOnlyLayer(nn.Module):
|
|
|
547
578
|
self.self_attn.phase = phase
|
|
548
579
|
|
|
549
580
|
def get_pre_attention_layernorm(self) -> nn.LayerNorm:
|
|
550
|
-
return self.
|
|
581
|
+
return self.pre_attention_layernorm
|
|
551
582
|
|
|
552
583
|
def get_post_attention_layernorm(self) -> nn.LayerNorm:
|
|
553
|
-
return self.
|
|
584
|
+
return self.post_attention_layernorm
|
|
585
|
+
|
|
586
|
+
def get_pre_feedforward_layernorm(self) -> nn.LayerNorm:
|
|
587
|
+
return self.pre_feedforward_layernorm
|
|
588
|
+
|
|
589
|
+
def get_post_feedforward_layernorm(self) -> nn.LayerNorm:
|
|
590
|
+
return self.post_feedforward_layernorm
|
|
554
591
|
|
|
555
592
|
def get_mlp(self) -> nn.Module:
|
|
556
|
-
return self.
|
|
593
|
+
return self.mlp
|
|
557
594
|
|
|
558
595
|
def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
559
596
|
mlp = self.get_mlp()
|
|
@@ -619,6 +656,8 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding: Whether this is sliding window attention
     """
 
+    _O_PROJ_ATTRS = ["o_proj", "out_proj", "dense"]
+
     def __init__(
         self,
         self_attn,
@@ -626,20 +665,18 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding=False,
     ):
         super().__init__()
-        self.
+        self.config = getattr(self_attn, "config", None)
         self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
-        self.num_heads = (
-
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self._phase = "prefill"
-        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale(self_attn)))
 
-        if hasattr(
-            self.num_key_value_heads =
-        elif hasattr(
-            self.num_key_value_heads =
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads
 
@@ -649,13 +686,16 @@ class DecoderOnlyAttention(nn.Module):
         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
         self.lora_config = rbln_config.lora_config
 
+        if hasattr(self_attn, "sinks"):
+            self.sinks = self_attn.sinks.data[:, None]
+
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.__post_init__()
+        self.__post_init__(self_attn)
 
     def _init_lora_weights(self):
         """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
         for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
-            original_linear = getattr(self
+            original_linear = getattr(self, proj_name)
             lora_linear = LoRALinear(
                 original_linear=original_linear,
                 lora_config=self.lora_config,
@@ -712,16 +752,15 @@ class DecoderOnlyAttention(nn.Module):
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
-    def __post_init__(self):
+    def __post_init__(self, self_attn=None):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = _get_attr_from_candidates(self_attn, self._O_PROJ_ATTRS)
+
         # Initialize LoRA weights if configured, which will replace linear layers
         if self.lora_config:
             self._init_lora_weights()
-        else:
-            # Use original linear layers if no LoRA
-            self.q_proj = self._original_mod.q_proj
-            self.k_proj = self._original_mod.k_proj
-            self.v_proj = self._original_mod.v_proj
-            self.o_proj = self._original_mod.o_proj
 
     def projection(
         self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
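
The package's `LoRALinear` itself is not shown in this diff; as a rough illustration of the wrapping pattern applied to `q_proj`/`k_proj`/`v_proj`/`o_proj` (a generic sketch, not the RBLN implementation; rank, scaling, and initialization are assumptions):

```python
import torch
import torch.nn as nn

# Generic illustration of wrapping an existing projection with a low-rank adapter.
# This is NOT the package's LoRALinear; signature and scaling are assumptions.
class ToyLoRALinear(nn.Module):
    def __init__(self, original_linear: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = original_linear            # original projection, e.g. q_proj
        self.lora_a = nn.Linear(original_linear.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, original_linear.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)     # adapter starts as a no-op delta
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

q_proj = nn.Linear(64, 64, bias=False)
wrapped = ToyLoRALinear(q_proj)
out = wrapped(torch.randn(1, 4, 64))  # same output shape as the base projection
```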
@@ -752,8 +791,8 @@ class DecoderOnlyAttention(nn.Module):
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
         return apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    def get_attn_scale(self):
-        return 1 / math.sqrt(
+    def get_attn_scale(self, self_attn):
+        return 1 / math.sqrt(self_attn.head_dim)
 
     def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
@@ -810,6 +849,7 @@ class DecoderOnlyAttention(nn.Module):
             block_size=self.kvcache_block_size,
             k_scale=k_scale,
             v_scale=v_scale,
+            s_aux=getattr(self, "sinks", None),
         )
 
         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
@@ -882,6 +922,7 @@ class AttentionOp(nn.Module):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.
 
@@ -898,6 +939,7 @@ class AttentionOp(nn.Module):
             block_size: Block size for paged attention
             k_scale: Scale applied to key
             v_scale: Scale applied to value
+            s_aux: Auxiliary states for attention sinks
 
         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
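
The diff only states that `s_aux` carries attention-sink states (registered earlier as `self_attn.sinks.data[:, None]`, i.e. one logit per head). As a conceptual illustration of how sink logits are commonly folded into attention, independent of the RBLN kernel's actual semantics:

```python
import torch

# Conceptual sketch of attention sinks: per-head logits that take part in the
# softmax normalization but contribute no value. Shapes below are illustrative.
num_heads, q_len, kv_len, head_dim = 4, 1, 16, 32
scores = torch.randn(num_heads, q_len, kv_len)   # q @ k^T / sqrt(head_dim)
s_aux = torch.randn(num_heads, 1)                # sink logits, shape [num_heads, 1]

logits = torch.cat([scores, s_aux.unsqueeze(1).expand(num_heads, q_len, 1)], dim=-1)
probs = torch.softmax(logits, dim=-1)[..., :kv_len]  # drop the sink column after normalization

values = torch.randn(num_heads, kv_len, head_dim)
out = probs @ values                              # [num_heads, q_len, head_dim]
```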
@@ -953,6 +995,9 @@ class AttentionOp(nn.Module):
         op_args["k_scale"] = k_scale
         op_args["v_scale"] = v_scale
 
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1017,6 +1062,7 @@ class FlashAttentionOp(AttentionOp):
         block_size,
         k_scale=None,
         v_scale=None,
+        s_aux=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1070,6 +1116,9 @@ class FlashAttentionOp(AttentionOp):
         op_args["k_scale"] = k_scale
         op_args["v_scale"] = v_scale
 
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1122,6 +1171,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert self.quantization is None, "Sliding window attention does not support quantization"
         assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
@@ -1165,6 +1215,11 @@ class SlidingWindowAttentionOp(AttentionOp):
                 op_args["is_bidirectional"] = True
             else:
                 op_args["is_bidirectional"] = False
+        elif self.phase == "decode":
+            op_args["attn_mask"] = attn_mask
+
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1194,7 +1249,7 @@ class RotaryEmbedding(nn.Module):
         else:
             rope_type = "default"
 
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, "cpu", max_seq_len_cached)
         cache_position = torch.arange(0, max_seq_len_cached)
         cache_position_expanded = cache_position[:, None]
 
@@ -1271,3 +1326,22 @@ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tu
     query_states = torch.cat((query_rot, query_pass), dim=-1)
     key_states = torch.cat((key_rot, key_pass), dim=-1)
     return query_states, key_states
+
+
+def _get_attr_from_candidates(
+    src: object,
+    candidates: Optional[List[str]] = None,
+):
+    """
+    Get an attribute from a list of candidate names.
+
+    - If `candidates` is None, this attribute is treated as optional and returns None.
+    - Otherwise, returns `getattr(src, name)` for the first `name` in `candidates` that exists on `src`.
+    - Raises AttributeError if `candidates` is provided but none of the names exist on `src`.
+    """
+    if candidates is None:
+        return None
+    for name in candidates:
+        if hasattr(src, name):
+            return getattr(src, name)
+    raise AttributeError(f"None of the attributes {candidates} exist in {src}")
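
A quick usage sketch of the new helper; the toy module below is invented for illustration, and the (abridged) candidate lists follow the class attributes added above:

```python
import torch.nn as nn

# GPT-2 style decoders expose `wte`/`ln_f`, Llama-style decoders expose `embed_tokens`/`norm`;
# the helper resolves whichever attribute actually exists on the wrapped module.
class ToyGPT2Body(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(100, 16)
        self.ln_f = nn.LayerNorm(16)

body = ToyGPT2Body()
embed = _get_attr_from_candidates(body, ["embed_tokens", "wte"])              # -> body.wte
norm = _get_attr_from_candidates(body, ["norm", "final_layer_norm", "ln_f"])  # -> body.ln_f
optional = _get_attr_from_candidates(body, None)                              # -> None (optional attribute)
```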