optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py (new file, +508 lines)

@@ -0,0 +1,508 @@

# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
from typing import Any, Optional

import rebel
import torch
import torch.nn.functional as F

from ....utils.runtime_utils import RBLNPytorchRuntime
from ...modeling_outputs import RBLNDecoderOnlyOutput
from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


class RBLNPageTableManager:
    EMPTY_BLOCK = -1
    NO_BLOCKS_ERROR = (
        "No memory blocks are available for allocation. "
        "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
        "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
        "Using vllm-rbln should fix this issue and enhance inference performance."
    )

    def __init__(self, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
        self.rbln_config = rbln_config
        self.block_tables = torch.zeros(
            self.rbln_config.batch_size,
            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
            dtype=torch.int16,
        ).fill_(self.EMPTY_BLOCK)
        self.free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))

    def update_block(self, batch_idx: int, block_idx: int):
        """
        If the block is empty (empty_block), allocates a block from the free_block_pool.
        """
        if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
            raise IndexError(
                f"Invalid index(batch_idx={batch_idx}, block_idx={block_idx}): \n \
                BlockTable Shape(batch_axis, block_axis): {self.block_tables.shape}, BlockSize: {self.rbln_config.kvcache_block_size}"
            )

        if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
            if self.free_block_pool:
                block = self.free_block_pool.popleft()
                self.block_tables[batch_idx][block_idx] = block
            else:
                raise RuntimeError(self.NO_BLOCKS_ERROR)

    def replace_empty_block(self, block_tables: torch.Tensor):
        """
        Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
        """
        if not torch.any(block_tables == self.EMPTY_BLOCK):
            return block_tables.clone()
        elif self.free_block_pool:
            _free_block = self.free_block_pool[0]
            return torch.where(block_tables == self.EMPTY_BLOCK, _free_block, block_tables)
        else:
            raise RuntimeError(self.NO_BLOCKS_ERROR)

    def get_block_tables(
        self, cache_position: torch.Tensor, batch_idx: int = None, batch_size: int = None, phase: str = "prefill"
    ) -> torch.Tensor:
        """
        Manages and returns the KV cache block tables.
        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.

        Args:
            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.

        Returns:
            Updated block tables.
        """

        def get_global_block_tables():
            if not self.rbln_config.use_global_attention:
                return None

            if phase == "prefill":
                # Track previously used blocks and return them to the free_block_pool and
                # reset the current batch's block table to empty blocks
                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.EMPTY_BLOCK].tolist()
                self.free_block_pool.extend(prev_blocks)
                self.block_tables[batch_idx].fill_(self.EMPTY_BLOCK)

                # Get the start (s) and end (e) positions from cache_position and
                # iterate over the cache positions to allocate necessary blocks
                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
                    block_idx = position // self.rbln_config.kvcache_block_size
                    self.update_block(batch_idx, block_idx)

                return self.replace_empty_block(self.block_tables[batch_idx])
            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
            else:
                for b_idx in range(batch_size):
                    position = cache_position[b_idx][0].item()
                    block_idx = position // self.rbln_config.kvcache_block_size
                    self.update_block(b_idx, block_idx)

                return self.replace_empty_block(self.block_tables)

        def get_local_block_tables():
            if not self.rbln_config.use_local_attention:
                return None
            else:
                return (
                    torch.tensor([batch_idx], dtype=torch.int16)
                    if phase == "prefill"
                    else torch.arange(batch_size, dtype=torch.int16).view(batch_size, -1)
                )

        return get_global_block_tables(), get_local_block_tables()

    # Whether block_tables and local_block_tables are provided by the user
    def is_external_block_tables(
        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
    ):
        if self.rbln_config.cache_impl == "static" and block_tables is None:
            return False
        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
            return False
        elif self.rbln_config.cache_impl == "hybrid":
            if (block_tables is not None) != (local_block_tables is not None):
                raise ValueError(
                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
                )
            elif block_tables is None and local_block_tables is None:
                return False

        return True

    def get_block_tables_if_needed(
        self,
        batch_size,
        cache_position: torch.Tensor,
        batch_idx: int = None,
        phase: str = "prefill",
        block_tables: Optional[torch.Tensor] = None,
        local_block_tables: Optional[torch.Tensor] = None,
    ):
        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
        if not is_external_block_tables:
            block_tables, local_block_tables = self.get_block_tables(
                cache_position, batch_idx=batch_idx, batch_size=batch_size, phase=phase
            )

        return block_tables, local_block_tables, is_external_block_tables


class RBLNRuntimeModel(RBLNPytorchRuntime):
    mandatory_members = ["main_input_name", "embed_tokens"]

    def __init__(
        self,
        runtime: rebel.Runtime,
        phase: str,
        batch_size: int,
        dec_attn_mask: torch.Tensor,
        page_table_manager: RBLNPageTableManager,
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
        out_buffers: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(runtime, **kwargs)
        self.phase = phase
        self.batch_size = batch_size
        self.rbln_config = rbln_config

        # shared resources between prefill and decode phase
        self.dec_attn_mask = dec_attn_mask
        self.page_table_manager = page_table_manager

        if self.phase == "prefill":
            self.out_buffers = out_buffers
            self.causal_mask = 1 - torch.triu(
                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
            )

        self.lora_int_ids = None

    def inputs_embeddings_if_needed(
        self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ):
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")

        if self.rbln_config.use_inputs_embeds:
            return self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds
        else:
            return input_ids

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cache_position: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        batch_idx: Optional[int] = None,
        block_tables: Optional[torch.Tensor] = None,
        position_embed: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        local_block_tables: Optional[torch.Tensor] = None,
        lora_int_ids: Optional[torch.Tensor] = None,
    ):
        inputs = self.inputs_embeddings_if_needed(input_ids, inputs_embeds)
        block_tables, local_block_tables, is_external_block_tables = (
            self.page_table_manager.get_block_tables_if_needed(
                self.batch_size,
                cache_position,
                batch_idx=batch_idx,
                phase=self.phase,
                block_tables=block_tables,
                local_block_tables=local_block_tables,
            )
        )

        if self.phase == "decode":
            return self.decode_forward(
                inputs,
                cache_position,
                block_tables,
                is_external_block_tables,
                attention_mask=attention_mask,
                position_embed=position_embed,
                position_ids=position_ids,
                local_block_tables=local_block_tables,
                lora_int_ids=lora_int_ids,
            )
        else:
            return self.prefill_forward(
                inputs,
                cache_position,
                attention_mask,
                batch_idx,
                block_tables,
                is_external_block_tables=is_external_block_tables,
                position_embed=position_embed,
                token_type_ids=token_type_ids,
                local_block_tables=local_block_tables,
                lora_int_ids=lora_int_ids,
            )

    def decode_forward(
        self,
        inputs: torch.Tensor,
        cache_position: torch.Tensor = None,
        block_tables: torch.Tensor = None,
        is_external_block_tables: bool = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embed: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        local_block_tables: Optional[torch.Tensor] = None,
        lora_int_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        if self.rbln_config.use_lora and lora_int_ids is None:
            if self.lora_int_ids is None:
                raise ValueError(
                    "lora_int_id is required when using LoRA. "
                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
                )

            lora_int_ids = self.lora_int_ids

        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")

        if self.batch_size != cache_position.shape[0]:
            raise RuntimeError(
                f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.batch_size}."
            )

        if self.rbln_config.use_attention_mask and attention_mask is None:
            for b_idx in range(self.batch_size):
                decoding_step = cache_position[b_idx].item()
                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                    raise ValueError(
                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                    )

                if is_external_block_tables:
                    self.dec_attn_mask[b_idx].fill_(0)
                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
                else:
                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

            attention_mask = self.dec_attn_mask

        logits = super().forward(
            inputs,
            cache_position,
            block_tables,
            local_block_tables,
            position_embed,
            attention_mask if self.rbln_config.use_attention_mask else None,
            position_ids if self.rbln_config.use_position_ids else None,
            lora_int_ids if self.rbln_config.use_lora else None,
        )

        return RBLNDecoderOnlyOutput(logits=logits)

    def _prepare_prefill_inputs(
        self,
        inputs: torch.Tensor,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_embed: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ):
        """
        Prepare inputs for prefill phase.
        """
        # Handle continuous batching in a compiled graph by extracting valid inputs
        # If an attention mask is provided, select only the valid (non-masked) inputs
        if attention_mask is not None:
            inputs = inputs[:, attention_mask.bool()]
            position_embed = None if position_embed is None else position_embed[:, :, :, attention_mask.bool(), :]
            token_type_ids = None if token_type_ids is None else token_type_ids[:, attention_mask.bool()]

        query_length = inputs.shape[1]
        if query_length > self.rbln_config.max_seq_len:
            raise ValueError(
                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
            )

        # Initialize attention mask for chunked processing
        chunked_attention_mask = (
            torch.zeros(
                1,
                1,
                self.rbln_config.prefill_chunk_size,
                self.rbln_config.max_seq_len,
                dtype=self.rbln_config.torch_dtype,
            )
            if self.rbln_config.use_attention_mask
            else None
        )

        cache_position = (
            torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
        )
        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
        if padding_size > 0:
            inputs = (
                F.pad(inputs, (0, 0, 0, padding_size))
                if self.rbln_config.use_inputs_embeds
                else F.pad(inputs, (0, padding_size))
            )
            position_embed = F.pad(position_embed, (0, 0, 0, padding_size)) if position_embed is not None else None
            token_type_ids = F.pad(token_type_ids, (0, padding_size), value=-1) if token_type_ids is not None else None
            cache_position = F.pad(cache_position, (0, padding_size))

        # Overwrite position_ids and padded_cache_lengths
        position_ids = cache_position.clone() if self.rbln_config.use_position_ids else None
        padded_cache_lengths = 0

        return (
            inputs,
            cache_position,
            chunked_attention_mask,
            position_ids,
            position_embed,
            padded_cache_lengths,
            query_length,
            token_type_ids,
        )

    def prefill_forward(
        self,
        inputs: torch.Tensor,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        batch_idx: Optional[int] = None,
        block_tables: Optional[torch.Tensor] = None,
        is_external_block_tables: Optional[bool] = None,
        position_embed: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        local_block_tables: Optional[torch.Tensor] = None,
        lora_int_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """
        Performs chunked prefill for efficient KV-cache updates and memory optimization.
        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
        """
        if self.rbln_config.use_lora and lora_int_ids is None:
            if self.lora_int_ids is None:
                raise ValueError(
                    "lora_int_id is required when using LoRA. "
                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
                )

            if batch_idx is not None:
                lora_int_ids = self.lora_int_ids[batch_idx : batch_idx + 1].clone()
            else:
                lora_int_ids = self.lora_int_ids.clone()

        (
            inputs,
            cache_position,
            chunked_attention_mask,
            position_ids,
            position_embed,
            padded_cache_lengths,
            query_length,
            token_type_ids,
        ) = self._prepare_prefill_inputs(
            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
        )

        # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
        prefix_cached_len = cache_position[0][0].item()
        if prefix_cached_len > 0:
            if prefix_cached_len % self.rbln_config.prefill_chunk_size != 0:
                raise NotImplementedError(
                    "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
                )
            if self.rbln_config.use_attention_mask:
                chunked_attention_mask[:, :, :, :prefix_cached_len] = 1

        # Process input in chunks of size `prefill_chunk_size`
        output_logits = []
        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
            s, e = step, step + self.rbln_config.prefill_chunk_size
            # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
            input_chunk = inputs[:, s:e]
            cache_pos_chunk = cache_position[:, s:e]
            position_ids_chunk = position_ids[:, s:e] if self.rbln_config.use_position_ids else None
            position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None

            # Update attention mask to ensure proper causal behavior
            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
                if step > 0:  # update previous chunk
                    chunked_attention_mask[
                        :,
                        :,
                        :,
                        s - self.rbln_config.prefill_chunk_size + prefix_cached_len : e
                        - self.rbln_config.prefill_chunk_size
                        + prefix_cached_len,
                    ] = 1
                chunked_attention_mask[:, :, :, s + prefix_cached_len : e + prefix_cached_len] = self.causal_mask

            # Calculate query position if needed
            if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
                query_position = (
                    torch.tensor((query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16)
                    if e >= query_length
                    else torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
                )
            else:
                query_position = None

            # Forward pass for the current chunk
            output_logit = super().forward(
                input_chunk,
                cache_pos_chunk,
                block_tables,
                local_block_tables,
                position_embed_chunk,
                query_position,
                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                position_ids_chunk,
                lora_int_ids if self.rbln_config.use_lora else None,
                out=self.out_buffers,
            )
            output_logits.append(output_logit)

        # Aggregate output_logits
        output_logits = torch.concat(output_logits, dim=-2)
        if self.rbln_config.logits_to_keep > 0:
            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
        else:
            output_logits = output_logits[:, :query_length, :]
            # index copy for masked output_logits
            if attention_mask is not None:
                new_output_logits = torch.full(
                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
                    fill_value=1e-10,
                    dtype=output_logits.dtype,
                )
                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)

                output_logits = new_output_logits

        # Update decoder attention mask with processed KV-cache length from prefill phase
        if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
            self.dec_attn_mask[batch_idx].fill_(0)
            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

        return RBLNDecoderOnlyOutput(logits=output_logits, padded_cache_lengths=padded_cache_lengths)
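
For orientation, the snippet below is a minimal standalone sketch (not part of the package) of how RBLNPageTableManager above hands out KV-cache blocks across the prefill and decode phases. A types.SimpleNamespace stands in for the real RBLNDecoderOnlyModelForCausalLMConfig, and every size used here is an illustrative value, not a default from this release.

from types import SimpleNamespace

import torch

# Stand-in config: only the attributes the manager reads are provided.
cfg = SimpleNamespace(
    batch_size=2,
    max_seq_len=512,
    kvcache_block_size=128,
    kvcache_num_blocks=8,
    use_global_attention=True,
    use_local_attention=False,
)
manager = RBLNPageTableManager(cfg)

# Prefill each batch row separately, mirroring how the runtime prefills one
# request at a time for continuous batching.
row0 = torch.arange(256, dtype=torch.int32).unsqueeze(0)  # 256 tokens -> blocks 0 and 1
row1 = torch.arange(40, dtype=torch.int32).unsqueeze(0)   # 40 tokens  -> block 2
bt0, _ = manager.get_block_tables(row0, batch_idx=0, phase="prefill")
bt1, _ = manager.get_block_tables(row1, batch_idx=1, phase="prefill")

# One decode step for the whole batch: each row advances by a single position and
# a fresh block is popped from the free pool only when a row crosses a block boundary.
decode_positions = torch.tensor([[256], [40]], dtype=torch.int32)
global_bt, local_bt = manager.get_block_tables(decode_positions, batch_size=2, phase="decode")
# Row 0 (position 256) gains a newly allocated third block; row 1 keeps reusing block 2.

Empty slots in the returned tables are filled with a dummy block id by replace_empty_block(), so the compiled graph always sees a fully populated block table even for short sequences.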
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py (new file, +119 lines)

@@ -0,0 +1,119 @@

# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import torch
from transformers import GenerationConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import ModelOutput


if TYPE_CHECKING:
    from ...modeling_outputs import RBLNDecoderOnlyOutput


class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
    _supports_cache_class = False  # Needed for GenerationMixin
    _is_stateful = False  # Needed for GenerationMixin

    def _reorder_cache(self, past_key_values, beam_idx):
        raise NotImplementedError

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        generate_idx: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        padded_cache_lengths: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        model_inputs = {}
        is_prefill_phase = generate_idx is None

        if is_prefill_phase:
            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
            padded_cache_lengths = torch.zeros_like(generate_idx)
            cache_position = None
            position_ids = None
        else:
            if inputs_embeds is not None:
                # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
                inputs_embeds = None

            input_ids = input_ids[:, -1:]
            position_ids = generate_idx
            cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
            generate_idx = generate_idx + 1
            model_inputs.update({"input_ids": input_ids})

        if inputs_embeds is not None:
            if self.rbln_config.use_inputs_embeds:
                model_inputs.update({"inputs_embeds": inputs_embeds})
            else:
                raise ValueError(
                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                )
        else:
            model_inputs.update({"input_ids": input_ids})

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "generate_idx": generate_idx,
                "position_ids": position_ids,
                "padded_cache_lengths": padded_cache_lengths,
            }
        )

        return model_inputs

    def _update_model_kwargs_for_generation(
        self, outputs: "RBLNDecoderOnlyOutput", model_kwargs: Dict[str, Any], **kwargs
    ) -> Dict[str, Any]:
        # update generate_idx
        model_kwargs["generate_idx"] = outputs.generate_idx
        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
        return model_kwargs

    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ) -> Union[ModelOutput, torch.LongTensor]:
        """
        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.

        Args:
            input_ids (torch.LongTensor): The input ids to the model.
            attention_mask (torch.LongTensor, optional): The attention mask to the model.
            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
                If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.

        Returns:
            A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.
        """
        if generation_config is not None:
            kwargs["generation_config"] = generation_config
        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask

        return super().generate(input_ids, **kwargs)
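
As the docstring above notes, generate() keeps the standard Hugging Face calling convention. The sketch below shows that call path through one of the RBLN causal LM classes in this release; the checkpoint id, the rbln_* options, and the prompt are illustrative placeholders rather than values taken from this diff.

from transformers import AutoTokenizer

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint

# export=True compiles the Hugging Face checkpoint for RBLN NPUs; the rbln_*
# options are example values, not defaults from this package.
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# prepare_inputs_for_generation() above derives generate_idx from the attention
# mask during prefill, so the mask returned by the tokenizer is passed along.
inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))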