optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import asdict, dataclass
 from typing import Any, Dict, List, Literal, Optional, Union, get_args
 
-from ....configuration_utils import RBLNModelConfig
+from ....configuration_utils import RBLNModelConfig, RBLNSerializableConfigProtocol
 from ....utils.logging import get_logger
 from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
@@ -59,6 +60,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         phases: Optional[List[PhaseType]] = None,
         logits_to_keep: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
+        kvcache_metas: Optional[List["KVCacheMeta"]] = None,
         **kwargs,
     ):
         """
@@ -93,8 +95,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
                 processing input sequences. Defaults to 128. Must be a positive integer
                 divisible by 64. Affects prefill performance and memory usage.
             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
-                PagedAttention KV cache
-                section below for details.
+                PagedAttention KV cache at compile time. Defaults to 0 (automatically determined).
+                See the "KV Cache Number of Blocks (`kvcache_num_blocks`)" section below for details.
             decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
                 This allows the model to handle varying batch sizes efficiently during generation. If not specified,
                 defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
@@ -114,6 +116,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
                 Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
             output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
+            kvcache_metas (Optional[List["KVCacheMeta"]]): The metadata for the KV cache tensors. Handled internally if not provided. Defaults to None.
             kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
@@ -152,17 +155,15 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
 
 
         KV Cache Number of Blocks:
-            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache
-            Each block holds `kvcache_block_size` tokens of Key and Value states.
-
-            - **Automatic
-
-
-
-
-
-            - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
-              but requires careful consideration of memory limits. Setting it too high may lead to
+            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache
+            at compile time. Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+            - **Automatic Determination (Default)**: If `kvcache_num_blocks` is `0` (default), the number of blocks
+              is automatically determined during compilation to fit within the available DRAM on the NPU. This allows
+              the model to utilize the remaining memory after compilation without manual tuning, providing optimal
+              cache capacity for better performance with long sequences or larger batches.
+            - **Manual Setting**: You can explicitly set the number of blocks to a positive integer. This provides
+              finer control but requires careful consideration of memory limits. Setting it too high may lead to
               compilation errors if it exceeds available memory. The system will issue warnings if your
               setting exceeds the estimated maximum.
             - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
@@ -175,7 +176,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
 
             The optimal value depends on the specific model, task, hardware, and desired trade-off
-            between performance and memory usage.
+            between performance and memory usage. Automatic determination (default) provides a robust starting point
+            that adapts to the available DRAM on the NPU at compile time.
         """
 
         super().__init__(**kwargs)
```
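Together with the constructor hunk that follows (where `kvcache_num_blocks` now defaults to `0`), the configuration can either leave block allocation to the compiler or pin it explicitly. A minimal usage sketch, assuming the Llama classes exported by `optimum.rbln`; the model id, sequence length, and block count below are illustrative, not defaults from the package:

```python
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

# Automatic determination (new default): kvcache_num_blocks stays 0 and the
# compiler sizes the PagedAttention KV cache to the NPU DRAM left after compilation.
auto_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, batch_size=1)

# Manual setting: pin the block count explicitly (must fit in device memory,
# and a warning is issued if it exceeds the estimated maximum).
manual_cfg = RBLNLlamaForCausalLMConfig(max_seq_len=8192, batch_size=1, kvcache_num_blocks=64)

# Illustrative compile-and-load call; requires an RBLN NPU and SDK to actually run.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model id
    export=True,
    rbln_config=auto_cfg,
)
```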
```diff
--- a/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -222,7 +224,7 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
 
-        self.kvcache_num_blocks = kvcache_num_blocks
+        self.kvcache_num_blocks = kvcache_num_blocks if kvcache_num_blocks is not None else 0
         self.cache_impl = cache_impl or "static"
         self.sliding_window = sliding_window
         self.sliding_window_layers = sliding_window_layers or []
@@ -257,6 +259,8 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
         # Larger batch size should be at the beginning of the list.
         self.decoder_batch_sizes.sort(reverse=True)
 
+        self.kvcache_metas: List["KVCacheMeta"] = kvcache_metas or []
+
     @staticmethod
     def validate_phases_type(phases: List[PhaseType]):
         if not isinstance(phases, list):
@@ -290,6 +294,21 @@ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
             return self.quantization.nbits_per_param
         return 16
 
+    @property
+    def is_auto_num_blocks(self) -> bool:
+        """Returns True if kvcache_num_blocks will be automatically determined during compilation to fit within the available DRAM on the NPU."""
+        return self.kvcache_num_blocks == 0
+
+    @property
+    def num_full_blocks(self) -> int:
+        return (self.max_seq_len // self.kvcache_block_size) * self.batch_size
+
+    @property
+    def num_min_blocks(self) -> int:
+        if self.attn_impl == "flash_attn":
+            return min(self.max_seq_len // self.kvcache_block_size + 1, self.num_full_blocks)
+        return self.batch_size
+
 
 class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
     """
```
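A worked sketch of the block arithmetic implemented by the new `num_full_blocks` and `num_min_blocks` properties above, using made-up sizes rather than package defaults:

```python
# Illustrative values only: max_seq_len=8192, kvcache_block_size=128, batch_size=2.
max_seq_len, kvcache_block_size, batch_size = 8192, 128, 2

# num_full_blocks: enough blocks for every sequence in the batch to reach max_seq_len.
num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 64 * 2 = 128

# num_min_blocks with flash attention: one full sequence's worth of blocks plus one spare,
# capped at num_full_blocks; with plain paged attention it is just one block per batch element.
num_min_blocks_flash = min(max_seq_len // kvcache_block_size + 1, num_full_blocks)  # 65
num_min_blocks_paged = batch_size  # 2

print(num_full_blocks, num_min_blocks_flash, num_min_blocks_paged)
```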
```diff
--- a/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -302,3 +321,86 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
 
     _default_phases = ["prefill", "decode"]
     _default_logits_to_keep = 1
+
+
+@dataclass
+class KVCacheMeta(RBLNSerializableConfigProtocol):
+    """
+    KVCacheMeta contains metadata describing the key-value (KV) cache tensor for a specific transformer layer.
+
+    This is used during compilation and runtime on RBLN devices to manage memory and configure the
+    static or dynamic characteristics of the cache implementation for decoder-only models.
+
+    Attributes:
+        name (str): Logical name of the KV cache tensor.
+        layer_index (int): Index of the transformer layer corresponding to this cache.
+        shape (list[int]): The 4D shape of the cache tensor:
+            [num_blocks, num_heads, block_size, head_dim]. The number of blocks may be dynamic or static
+            depending on model configuration.
+        layer_type (str): String describing the attention/cache algorithm (e.g., "full_attention", "sliding_attention").
+        is_auto (bool): Whether the number of blocks is automatically determined during compilation (True) or manually specified (False).
+            In both cases, the KV cache size is fixed at compile time.
+        dtype (str): Data type of the cache buffer ("float16", "float32", etc.).
+    """
+
+    name: str
+    layer_index: int
+    shape: list[int]  # (num_blocks, num_heads, block_size(seq), head_dim)
+    layer_type: str
+    is_auto: bool
+    dtype: str
+
+    def _prepare_for_serialization(self) -> dict[str, Any]:
+        return asdict(self)
+
+    @property
+    def compile_shape(self):
+        return [1, self.shape[1], self.shape[2], self.shape[3]] if self.can_resize else self.shape
+
+    @property
+    def can_resize(self):
+        return self.is_auto and self.layer_type == "full_attention"
+
+    @property
+    def num_blocks(self) -> int:
+        return self.shape[0]
+
+    @property
+    def block_size(self) -> int:
+        return self.shape[2]
+
+    @staticmethod
+    def make(
+        name: str,
+        layer_index: int,
+        num_key_value_heads: int,
+        head_dim: int,
+        dtype: str,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> "KVCacheMeta":
+        assert len(rbln_config.compile_cfgs) == 0, "KVCacheMeta cannot be created from rbln_config with compile_cfgs"
+
+        if rbln_config.sliding_window is not None and layer_index in rbln_config.sliding_window_layers:
+            layer_type = "sliding_attention"
+            block_size = rbln_config.sliding_window
+            num_blocks = rbln_config.batch_size
+            is_auto = False
+
+        else:
+            layer_type = "full_attention"
+            block_size = rbln_config.kvcache_block_size
+
+            if rbln_config.is_auto_num_blocks:
+                num_blocks = rbln_config.num_full_blocks
+                is_auto = True
+            else:
+                num_blocks = rbln_config.kvcache_num_blocks
+                is_auto = False
+
+        shape = [num_blocks, num_key_value_heads, block_size, head_dim]
+        if num_blocks <= 0:
+            raise ValueError("`num_blocks` must be greater than 0 when using KV cache.")
+
+        return KVCacheMeta(
+            name=name, layer_index=layer_index, shape=shape, layer_type=layer_type, is_auto=is_auto, dtype=dtype
+        )
```
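A short sketch of how the new `KVCacheMeta` dataclass behaves, constructing it directly with illustrative field values; the import path assumes the module patched above (in practice instances are built internally via `KVCacheMeta.make`):

```python
# Path follows the configuration file shown in this diff; field values are made up.
from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import KVCacheMeta

meta = KVCacheMeta(
    name="past_key_values.0.key",   # hypothetical logical name
    layer_index=0,
    shape=[128, 8, 128, 64],        # [num_blocks, num_kv_heads, block_size, head_dim]
    layer_type="full_attention",
    is_auto=True,                   # block count left to the compiler
    dtype="float16",
)

assert meta.num_blocks == 128 and meta.block_size == 128
assert meta.can_resize                        # auto + full attention => resizable at compile time
assert meta.compile_shape == [1, 8, 128, 64]  # compiled with a single block when resizable
```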
```diff
--- a/optimum/rbln/transformers/models/decoderonly/configuration_lora.py
+++ b/optimum/rbln/transformers/models/decoderonly/configuration_lora.py
@@ -46,7 +46,7 @@ class RBLNLoRAAdapterConfig(RBLNSerializableConfigProtocol):
         model = RBLNLlamaForCausalLM.from_pretrained(
             model_id,
             rbln_config=RBLNLlamaForCausalLMConfig(lora_config=lora_config, tensor_parallel_size=tp_size, max_seq_len=8192),
-
+            dtype="auto",
         )
 
 
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -75,7 +75,7 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )
 
-        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len, use_rotary_emb)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
         self._phase = "prefill"
 
@@ -103,7 +103,7 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM
 
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int, use_rotary_emb: bool):
         new_layers = []
         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
             is_sliding = layer_idx in self.rbln_config.sliding_window_layers
@@ -118,6 +118,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_layers,
             self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+            use_rotary_emb=use_rotary_emb,
         )
 
         if self.is_causal_lm:
@@ -144,8 +145,11 @@ class DecoderOnlyWrapper(nn.Module):
         local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
         query_position = (
             args.pop(0)
-            # query_position usage:
-            if (
+            # query_position usage: prefill & (logits_to_keep == 1 or use_local_attention)
+            if (
+                "prefill" in self.phase
+                and (self.rbln_config.logits_to_keep == 1 or self.rbln_config.use_local_attention)
+            )
             else None
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
@@ -240,7 +244,6 @@ class DecoderOnlyForCausalLM(nn.Module):
 
     Attributes:
         config: Configuration from the original causal language model
-        _original_mod: Reference to the original model for components like lm_head
        model: RBLN-optimized decoder model instance
        _phase: Current processing phase ("prefill" or "decode")
    """
@@ -248,10 +251,9 @@ class DecoderOnlyForCausalLM(nn.Module):
     def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
         super().__init__()
         self.config = causal_lm.config
-        self._original_mod = causal_lm
         self.model = model
         self._phase = "prefill"
-        self.lm_head =
+        self.lm_head = causal_lm.lm_head
 
     @property
     def phase(self):
@@ -293,7 +295,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             output_hidden_states=output_hidden_states,
         )
 
-        if "prefill" in self.phase:
+        if "prefill" in self.phase and query_position is not None:
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self.lm_head(hidden_states)
@@ -317,20 +319,35 @@ class DecoderOnlyModel(nn.Module):
         use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
     Attributes:
-        _original_mod: Reference to original Huggingface model
         layers: ModuleList of RBLN-optimized transformer layers
         _phase: Current processing phase ("prefill" or "decode")
     """
 
+    _EMBEDDING_ATTRS = ["embed_tokens", "wte"]
+    _POSITION_ATTRS = ["embed_positions", "wpe"]
+    _LAYERNORM_ATTRS = ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+
     def __init__(
         self,
         model,
         layers: List["DecoderOnlyLayer"],
         rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
+        use_rotary_emb=True,
     ):
         super().__init__()
-        self.
+        self.config = model.config
+        # Keep commonly-used original submodules registered on this wrapper so their weights
+        # are preserved in state_dict even if the original model object is not kept.
+        # Different HF model families use different attribute names; we register what we can
+        # and allow subclasses to override getters when needed.
+        self.embed_tokens = _get_attr_from_candidates(model, self._EMBEDDING_ATTRS)
+        # hasattr(model, "rotary_emb") is workaround for Qwen2VL
+        if not (use_rotary_emb or hasattr(model, "rotary_emb")):
+            self.embed_positions = _get_attr_from_candidates(model, self._POSITION_ATTRS)
+        self.norm = _get_attr_from_candidates(model, self._LAYERNORM_ATTRS)
         self.layers = nn.ModuleList(layers)
         self.rbln_config = rbln_config
         self._phase = "prefill"
@@ -369,26 +386,28 @@ class DecoderOnlyModel(nn.Module):
         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
         return cache_pos_for_partitions
 
-    def
-        max_cache_len = self.
+    def get_swa_custom_op_args(self, position_ids, query_position):
+        max_cache_len = self.config.sliding_window
         valid_input_len = 1 if query_position is None else query_position + 1
-        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
+        cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1]  # past seen tokens
         cache_offset = (
             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
         )  # cache offset for next steps
 
-
+        # Causal mask for sliding window attention
+        attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
+        attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]
+
+        return cache_seq_len, cache_offset, attn_mask
 
     def get_last_layernorm(self) -> nn.LayerNorm:
-        return self.
+        return self.norm
 
     def get_embedding(self) -> nn.Embedding:
-        return self.
+        return self.embed_tokens
 
     def get_pos_embedding(self) -> nn.Embedding:
-
-        "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
-        )
+        return self.embed_positions
 
     def forward(
         self,
```
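A standalone sketch of the decode-time sliding-window mask built in `get_swa_custom_op_args` above, with toy numbers (window of 8 cached positions, two sequences that have seen 3 and 8+ tokens):

```python
import torch

# Toy values only; the real values come from config.sliding_window and the runtime position_ids.
max_cache_len = 8
position_ids = torch.tensor([[3], [11]])

cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1]  # [[3], [8]]

# Window slots beyond the tokens already cached are masked out (0.0); the rest are kept (1.0).
attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]

print(attn_mask.shape)     # torch.Size([2, 1, 1, 8])
print(attn_mask[0, 0, 0])  # tensor([1., 1., 1., 1., 0., 0., 0., 0.])
```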
```diff
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -464,7 +483,8 @@ class DecoderOnlyModel(nn.Module):
 
         # Get local cache positions for sliding window layers
         if len(self.sliding_window_layers) > 0:
-
+            cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+            sliding_cache_pos = (cache_seq_len, cache_offset)
 
         all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
@@ -472,9 +492,10 @@ class DecoderOnlyModel(nn.Module):
                 all_hidden_states += (hidden_states,)
 
             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos,
@@ -510,14 +531,23 @@ class DecoderOnlyLayer(nn.Module):
         self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
 
     Attributes:
-        _original_mod: Reference to original layer for accessing components
         self_attn: Modified attention mechanism mapped to RBLN ops at compile time
         phase: Current operation phase ("prefill" or "decode")
     """
 
+    _PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]
+    _POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+
     def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
-
+
+        self.pre_attention_layernorm = _get_attr_from_candidates(layer, self._PRE_ATTN_LAYERNORM)
+        self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
+        self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
+        self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
+        self.mlp = layer.mlp
         self.self_attn = self_attn
         self._phase = "prefill"
         self.lora_config = lora_config
@@ -547,13 +577,19 @@ class DecoderOnlyLayer(nn.Module):
         self.self_attn.phase = phase
 
     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self.
+        return self.pre_attention_layernorm
 
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self.
+        return self.post_attention_layernorm
+
+    def get_pre_feedforward_layernorm(self) -> nn.LayerNorm:
+        return self.pre_feedforward_layernorm
+
+    def get_post_feedforward_layernorm(self) -> nn.LayerNorm:
+        return self.post_feedforward_layernorm
 
     def get_mlp(self) -> nn.Module:
-        return self.
+        return self.mlp
 
     def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
         mlp = self.get_mlp()
@@ -619,6 +655,8 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding: Whether this is sliding window attention
     """
 
+    _O_PROJ_ATTRS = ["o_proj", "out_proj", "dense"]
+
     def __init__(
         self,
         self_attn,
@@ -626,20 +664,18 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding=False,
     ):
         super().__init__()
-        self.
+        self.config = getattr(self_attn, "config", None)
         self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
-        self.num_heads = (
-
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self._phase = "prefill"
-        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale(self_attn)))
 
-        if hasattr(
-            self.num_key_value_heads =
-        elif hasattr(
-            self.num_key_value_heads =
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads
 
@@ -649,13 +685,16 @@ class DecoderOnlyAttention(nn.Module):
         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
         self.lora_config = rbln_config.lora_config
 
+        if hasattr(self_attn, "sinks"):
+            self.sinks = self_attn.sinks.data[:, None]
+
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.__post_init__()
+        self.__post_init__(self_attn)
 
     def _init_lora_weights(self):
         """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
         for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
-            original_linear = getattr(self
+            original_linear = getattr(self, proj_name)
             lora_linear = LoRALinear(
                 original_linear=original_linear,
                 lora_config=self.lora_config,
@@ -712,16 +751,15 @@ class DecoderOnlyAttention(nn.Module):
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
-    def __post_init__(self):
+    def __post_init__(self, self_attn=None):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = _get_attr_from_candidates(self_attn, self._O_PROJ_ATTRS)
+
         # Initialize LoRA weights if configured, which will replace linear layers
         if self.lora_config:
             self._init_lora_weights()
-        else:
-            # Use original linear layers if no LoRA
-            self.q_proj = self._original_mod.q_proj
-            self.k_proj = self._original_mod.k_proj
-            self.v_proj = self._original_mod.v_proj
-            self.o_proj = self._original_mod.o_proj
 
     def projection(
         self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
@@ -752,8 +790,8 @@ class DecoderOnlyAttention(nn.Module):
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
         return apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    def get_attn_scale(self):
-        return 1 / math.sqrt(
+    def get_attn_scale(self, self_attn):
+        return 1 / math.sqrt(self_attn.head_dim)
 
     def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
@@ -810,6 +848,7 @@ class DecoderOnlyAttention(nn.Module):
             block_size=self.kvcache_block_size,
             k_scale=k_scale,
             v_scale=v_scale,
+            s_aux=getattr(self, "sinks", None),
         )
 
         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
@@ -882,6 +921,7 @@ class AttentionOp(nn.Module):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.
 
@@ -898,6 +938,7 @@ class AttentionOp(nn.Module):
             block_size: Block size for paged attention
             k_scale: Scale applied to key
             v_scale: Scale applied to value
+            s_aux: Auxiliary states for attention sinks
 
         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -953,6 +994,9 @@ class AttentionOp(nn.Module):
             op_args["k_scale"] = k_scale
             op_args["v_scale"] = v_scale
 
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1017,6 +1061,7 @@ class FlashAttentionOp(AttentionOp):
         block_size,
         k_scale=None,
         v_scale=None,
+        s_aux=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1070,6 +1115,9 @@ class FlashAttentionOp(AttentionOp):
             op_args["k_scale"] = k_scale
             op_args["v_scale"] = v_scale
 
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1122,6 +1170,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert self.quantization is None, "Sliding window attention does not support quantization"
         assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
@@ -1165,6 +1214,11 @@ class SlidingWindowAttentionOp(AttentionOp):
                 op_args["is_bidirectional"] = True
             else:
                 op_args["is_bidirectional"] = False
+        elif self.phase == "decode":
+            op_args["attn_mask"] = attn_mask
+
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1194,7 +1248,7 @@ class RotaryEmbedding(nn.Module):
         else:
             rope_type = "default"
 
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, "cpu", max_seq_len_cached)
         cache_position = torch.arange(0, max_seq_len_cached)
         cache_position_expanded = cache_position[:, None]
 
@@ -1271,3 +1325,22 @@ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tu
     query_states = torch.cat((query_rot, query_pass), dim=-1)
     key_states = torch.cat((key_rot, key_pass), dim=-1)
     return query_states, key_states
+
+
+def _get_attr_from_candidates(
+    src: object,
+    candidates: Optional[List[str]] = None,
+):
+    """
+    Get an attribute from a list of candidate names.
+
+    - If `candidates` is None, this attribute is treated as optional and returns None.
+    - Otherwise, returns `getattr(src, name)` for the first `name` in `candidates` that exists on `src`.
+    - Raises AttributeError if `candidates` is provided but none of the names exist on `src`.
+    """
+    if candidates is None:
+        return None
+    for name in candidates:
+        if hasattr(src, name):
+            return getattr(src, name)
+    raise AttributeError(f"None of the attributes {candidates} exist in {src}")
```
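A toy sketch of the `_get_attr_from_candidates` lookup added above; `ToyBlock` is a made-up module, and the import assumes the helper stays in the architecture module patched in this diff:

```python
import torch.nn as nn

# Path follows the file shown in this diff; the helper is module-private, so this is an assumption.
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import _get_attr_from_candidates


class ToyBlock(nn.Module):
    """Toy module exposing a GPT-2-style layer norm attribute name."""

    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(16)


block = ToyBlock()

# Picks the first candidate attribute that exists on the module ("ln_1" here).
norm = _get_attr_from_candidates(block, ["input_layernorm", "ln_1"])
assert norm is block.ln_1

# A None candidate list marks the attribute as optional and yields None.
assert _get_attr_from_candidates(block, None) is None
```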
```diff
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py
@@ -177,7 +177,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         page_table_manager: RBLNPageTableManager,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        config: "PreTrainedConfig" = None,
+        config: Optional["PreTrainedConfig"] = None,
         logits_last_dim: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
@@ -391,16 +391,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Initialize attention mask for chunked processing
         if self.rbln_config.use_attention_mask:
             if self.rbln_config.use_position_ids:
-                chunked_attention_mask = torch.zeros(
-                    1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
-                )
+                chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=self.rbln_config.dtype)
             else:
                 chunked_attention_mask = torch.zeros(
                     1,
                     1,
                     self.rbln_config.prefill_chunk_size,
                     self.rbln_config.max_seq_len,
-                    dtype=self.rbln_config.
+                    dtype=self.rbln_config.dtype,
                 )
         else:
             chunked_attention_mask = None
@@ -467,7 +465,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
             logits_last_dim,
         )
-        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.
+        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
 
         if self.rbln_config.logits_to_keep == 1:
             for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
@@ -486,7 +484,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 self.config.hidden_size,
             )
             output_hidden_states = [
-                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.
+                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
                 for _ in range(self.config.num_hidden_layers + 1)
             ]
 
```