optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import deque
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import rebel
 import torch

@@ -24,6 +24,10 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


+if TYPE_CHECKING:
+    from transformers.configuration_utils import PreTrainedConfig
+
+
 class RBLNPageTableManager:
     EMPTY_BLOCK = -1
     NO_BLOCKS_ERROR = (
@@ -46,6 +50,12 @@ class RBLNPageTableManager:
         """
         If the block is empty (empty_block), allocates a block from the free_block_pool.
         """
+        if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+            raise IndexError(
+                f"Invalid index(batch_idx={batch_idx}, block_idx={block_idx}): \n \
+                BlockTable Shape(batch_axis, block_axis): {self.block_tables.shape}, BlockSize: {self.rbln_config.kvcache_block_size}"
+            )
+
         if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
             if self.free_block_pool:
                 block = self.free_block_pool.popleft()
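For orientation, here is a minimal, self-contained sketch of the paged KV-cache bookkeeping this hunk hardens: a block table backed by a free-block pool, with the new bounds check performed before any allocation. Class, method, and field names mirror the diff, but the code below is illustrative, not the library implementation.

```python
# Minimal sketch of a bounds-checked, lazily allocated block table (illustrative only).
from collections import deque

import torch

EMPTY_BLOCK = -1


class TinyPageTableManager:
    def __init__(self, batch_size: int, blocks_per_seq: int, num_blocks: int):
        # One row of logical block slots per sequence; -1 means "no physical block yet".
        self.block_tables = torch.full((batch_size, blocks_per_seq), EMPTY_BLOCK, dtype=torch.int16)
        self.free_block_pool = deque(range(num_blocks))

    def update_block(self, batch_idx: int, block_idx: int) -> None:
        # Reject out-of-range indices before touching the table (the behavior added in this release).
        if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
            raise IndexError(
                f"Invalid index (batch_idx={batch_idx}, block_idx={block_idx}); "
                f"block table shape is {tuple(self.block_tables.shape)}"
            )
        # Lazily allocate a physical block the first time this logical slot is used.
        if self.block_tables[batch_idx][block_idx] == EMPTY_BLOCK:
            if not self.free_block_pool:
                raise RuntimeError("No free KV-cache blocks available.")
            self.block_tables[batch_idx][block_idx] = self.free_block_pool.popleft()


mgr = TinyPageTableManager(batch_size=2, blocks_per_seq=4, num_blocks=8)
mgr.update_block(0, 0)        # allocates a physical block for sequence 0, slot 0
# mgr.update_block(0, 10)     # would raise IndexError: block_idx out of range
```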
@@ -96,8 +106,6 @@ class RBLNPageTableManager:
             s, e = cache_position[0][0].item(), cache_position[0][-1].item()
             for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
                 block_idx = position // self.rbln_config.kvcache_block_size
-                if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                    raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
                 self.update_block(batch_idx, block_idx)

         return self.replace_empty_block(self.block_tables[batch_idx])
@@ -169,20 +177,23 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         page_table_manager: RBLNPageTableManager,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-
+        config: Optional["PreTrainedConfig"] = None,
+        logits_last_dim: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
         self.rbln_config = rbln_config
+        self.config = config
+        self.logits_last_dim = logits_last_dim

         # shared resources between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.page_table_manager = page_table_manager
+        self.out_buffers = None

         if self.phase == "prefill":
-            self.out_buffers = out_buffers
             self.causal_mask = 1 - torch.triu(
                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
@@ -276,28 +287,48 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")

-
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
             raise RuntimeError(
-                f"
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        if self.rbln_config.use_local_attention:
+            local_block_tables = (
+                local_block_tables
+                if local_block_tables is not None
+                else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
             )

         if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(
+            for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )

-                if
-                self.dec_attn_mask[b_idx]
-
+                if self.rbln_config.use_position_ids:
+                    self.dec_attn_mask[b_idx, decoding_step] = 1
+
+                    if self.batch_size < block_tables.shape[0]:
+                        block_tables = block_tables[: self.batch_size]
+
+                    if self.dec_attn_mask is not None and self.batch_size < self.dec_attn_mask.shape[0]:
+                        self.dec_attn_mask = self.dec_attn_mask[: self.batch_size]
                 else:
-
+                    if is_external_block_tables:
+                        self.dec_attn_mask[b_idx].fill_(0)
+                        self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                    else:
+                        self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

             attention_mask = self.dec_attn_mask

-
+        outputs = super().forward(
             inputs,
             cache_position,
             block_tables,
@@ -306,15 +337,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
             lora_int_ids if self.rbln_config.use_lora else None,
+            out=self.out_buffers,
         )

-
+        if self.rbln_config.output_hidden_states:
+            return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+        else:
+            return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

     def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
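The decode-path changes above validate the runtime batch size against the compiled batch size and, when position ids are in use, appear to maintain a 2D decoder attention mask of shape [batch, max_seq_len] that is advanced by one position per decoding step. A standalone sketch of that masking pattern, with made-up sizes and positions (not library code):

```python
# Illustrative per-step update of a 2D decode attention mask.
import torch

batch_size, max_seq_len = 2, 8
dec_attn_mask = torch.zeros(batch_size, max_seq_len)
cache_position = torch.tensor([[3], [5]])  # current decoding step per sequence

for b_idx in range(batch_size):
    decoding_step = cache_position[b_idx].item()
    if not (0 <= decoding_step < dec_attn_mask.shape[-1]):
        raise ValueError(f"Decoding step {decoding_step} out of bounds.")
    # In the runtime the mask persists across steps, so earlier positions were already set to 1;
    # each call only has to mark the newly generated position.
    dec_attn_mask[b_idx, decoding_step] = 1
```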
@@ -324,9 +360,27 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         if attention_mask is not None:
-
-
-
+            if attention_mask.dim() != 1:
+                raise ValueError("attention_mask must be a 1D tensor.")
+
+            mask_bool = attention_mask.to(dtype=torch.bool)
+            if (~mask_bool).any():
+                indice_one = torch.nonzero(mask_bool, as_tuple=False)
+                if indice_one.numel() == 0:
+                    raise ValueError("attention_mask with padding must include at least one real token.")
+                first_one_idx, last_one_idx = int(indice_one[0].item()), int(indice_one[-1].item())
+                if last_one_idx - first_one_idx + 1 != mask_bool.sum():
+                    raise ValueError(
+                        "attention_mask must group all 1s together (e.g. 000111 or 1111000). "
+                        "Zeros between real tokens like 101010 are not supported."
+                    )
+
+                if self.rbln_config.can_generate and not mask_bool[first_one_idx:].all():
+                    raise ValueError("attention_mask must be left padded for generation.")
+
+            inputs = inputs[:, mask_bool]
+            position_embed = None if position_embed is None else position_embed[:, :, :, mask_bool, :]
+            token_type_ids = None if token_type_ids is None else token_type_ids[:, mask_bool]

         query_length = inputs.shape[1]
         if query_length > self.rbln_config.max_seq_len:
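The new prefill validation requires the 1D attention_mask to keep all real tokens contiguous so that padding can be stripped with a single boolean index. A small, self-contained sketch of that check and the subsequent input selection (simplified from the diff; the helper name is invented for illustration):

```python
# Contiguity check for a 1D prefill attention mask (illustrative helper, not library code).
import torch


def check_prefill_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    if attention_mask.dim() != 1:
        raise ValueError("attention_mask must be a 1D tensor.")
    mask_bool = attention_mask.to(dtype=torch.bool)
    if (~mask_bool).any():
        ones = torch.nonzero(mask_bool, as_tuple=False)
        if ones.numel() == 0:
            raise ValueError("attention_mask must include at least one real token.")
        first, last = int(ones[0].item()), int(ones[-1].item())
        # All 1s must form one contiguous run, e.g. 000111 or 111000.
        if last - first + 1 != int(mask_bool.sum()):
            raise ValueError("attention_mask must group all 1s together (e.g. 000111).")
    return mask_bool


mask_bool = check_prefill_mask(torch.tensor([0, 0, 1, 1, 1]))  # left padding: accepted
inputs = torch.arange(5).unsqueeze(0)   # shape [1, seq_len]
valid_inputs = inputs[:, mask_bool]     # keeps only real tokens -> tensor([[2, 3, 4]])
# check_prefill_mask(torch.tensor([1, 0, 1, 0]))  # would raise: zeros between real tokens
```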
@@ -335,17 +389,19 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         )

         # Initialize attention mask for chunked processing
-
-
-            1,
-
-
-
-
-
-
-
-
+        if self.rbln_config.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=self.rbln_config.dtype)
+            else:
+                chunked_attention_mask = torch.zeros(
+                    1,
+                    1,
+                    self.rbln_config.prefill_chunk_size,
+                    self.rbln_config.max_seq_len,
+                    dtype=self.rbln_config.dtype,
+                )
+        else:
+            chunked_attention_mask = None

         cache_position = (
             torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
@@ -363,7 +419,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position = F.pad(cache_position, (0, padding_size))

         # Overwrite position_ids and padded_cache_lengths
-
+        if self.rbln_config.use_position_ids and position_ids is None:
+            position_ids = cache_position.clone()
+        else:
+            position_ids = position_ids
+
         padded_cache_lengths = 0

         return (
@@ -377,6 +437,68 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             token_type_ids,
         )

+    def _prepare_prefill_outputs(
+        self,
+        query_length: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        # Prepare out buffers
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        padded_input_length = query_length + padding_size
+        padded_mask_length = (
+            attention_mask.shape[-1] + padding_size if attention_mask is not None else padded_input_length
+        )
+        out_buffers = [[] for _ in range(padded_input_length // self.rbln_config.prefill_chunk_size)]
+
+        valid_start_index = (
+            int(torch.nonzero(attention_mask, as_tuple=False)[0][0].item()) if attention_mask is not None else 0
+        )
+
+        if self.logits_last_dim is None:
+            logits_last_dim = self.config.vocab_size if self.rbln_config.can_generate else self.config.hidden_size
+        else:
+            logits_last_dim = self.logits_last_dim
+
+        # Prepare logits buffer
+        logits_size = (
+            1,
+            1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
+            logits_last_dim,
+        )
+        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
+
+        if self.rbln_config.logits_to_keep == 1:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                out_buffers[i].append(output_logits)
+        else:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].append(output_logits[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size])
+
+        # Prepare output hidden states
+        output_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            hidden_states_size = (
+                1,
+                padded_mask_length,
+                self.config.hidden_size,
+            )
+            output_hidden_states = [
+                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.dtype)
+                for _ in range(self.config.num_hidden_layers + 1)
+            ]
+
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].extend(
+                    [
+                        hidden_states_buffer[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size]
+                        for hidden_states_buffer in output_hidden_states
+                    ]
+                )
+
+        return out_buffers, output_logits, output_hidden_states
+
     def prefill_forward(
         self,
         inputs: torch.Tensor,
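The new _prepare_prefill_outputs helper preallocates one full-length logits buffer (and optional hidden-state buffers) and hands each prefill chunk a view into it, so the compiled runtime can write results in place chunk by chunk. A simplified sketch of that buffer layout, with hypothetical sizes:

```python
# Sketch of chunked output buffers backed by views into one preallocated tensor (illustrative).
import torch

prefill_chunk_size, query_length, vocab_size = 4, 10, 16
padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size   # -> 2
padded_length = query_length + padding_size                               # -> 12
num_chunks = padded_length // prefill_chunk_size                          # -> 3

output_logits = torch.full((1, padded_length, vocab_size), fill_value=1e-10)
out_buffers = [
    output_logits[:, i * prefill_chunk_size : (i + 1) * prefill_chunk_size]  # a view, no copy
    for i in range(num_chunks)
]

out_buffers[0].fill_(0.5)  # writing into a chunk buffer...
assert torch.all(output_logits[:, :prefill_chunk_size] == 0.5)  # ...updates the shared tensor
```

Because each chunk writes directly into its slice of the shared tensor, the results are already in their final positions when the loop finishes.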
@@ -385,6 +507,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         is_external_block_tables: Optional[bool] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
@@ -417,9 +540,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             query_length,
             token_type_ids,
         ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+            inputs, cache_position, attention_mask, position_ids, position_embed, token_type_ids=token_type_ids
        )

+        out_buffers, output_logits, output_hidden_states = self._prepare_prefill_outputs(query_length, attention_mask)
+
         # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
         prefix_cached_len = cache_position[0][0].item()
         if prefix_cached_len > 0:
@@ -428,11 +553,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
                 )
            if self.rbln_config.use_attention_mask:
-
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[:, :prefix_cached_len] = 1
+                else:
+                    chunked_attention_mask[:, :, :, :prefix_cached_len] = 1

         # Process input in chunks of size `prefill_chunk_size`
-
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+        for i, step in enumerate(range(0, query_length, self.rbln_config.prefill_chunk_size)):
             s, e = step, step + self.rbln_config.prefill_chunk_size
             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
             input_chunk = inputs[:, s:e]
@@ -441,17 +568,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None

             # Update attention mask to ensure proper causal behavior
-            if self.rbln_config.use_attention_mask
-                if
-
-
-
-
-
-
-
-
-
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    if step > 0:  # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = min(e, query_length) + prefix_cached_len
+                    if current_chunk_end > current_chunk_start:
+                        chunked_attention_mask[:, current_chunk_start:current_chunk_end] = 1
+
+                else:
+                    if step > 0:  # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, :, :, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = e + prefix_cached_len
+                    chunked_attention_mask[:, :, :, current_chunk_start:current_chunk_end] = self.causal_mask

             # Calculate query position if needed
             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
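The chunked prefill loop above maintains one attention mask for the whole sequence: columns belonging to chunks that have already been processed become fully visible, while the current chunk receives a lower-triangular causal block. A small sketch of that pattern for the 4D-mask case (illustrative sizes, not library code):

```python
# Chunked prefill masking pattern: processed chunks fully visible, current chunk causal.
import torch

prefill_chunk_size, max_seq_len = 4, 12
causal_mask = 1 - torch.triu(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1)
chunked_attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len)

for step in range(0, max_seq_len, prefill_chunk_size):
    s, e = step, step + prefill_chunk_size
    if step > 0:
        # The previous chunk has been processed, so it is now fully attendable.
        chunked_attention_mask[:, :, :, s - prefill_chunk_size : s] = 1
    # The current chunk only attends causally within itself.
    chunked_attention_mask[:, :, :, s:e] = causal_mask
    # ... run the compiled prefill step for this chunk here ...
```

The 2D (use_position_ids) branch in the diff follows the same idea with a [1, max_seq_len] mask, shifted by the externally cached prefix length.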
@@ -464,7 +603,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position = None

             # Forward pass for the current chunk
-
+            _ = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
@@ -474,31 +613,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
                 lora_int_ids if self.rbln_config.use_lora else None,
-                out=
+                out=out_buffers[i],
             )
-            output_logits.append(output_logit)

         # Aggregate output_logits
-
-        if self.rbln_config.logits_to_keep
-            output_logits = output_logits
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        if self.rbln_config.logits_to_keep == 1:
+            output_logits = output_logits
+        elif self.rbln_config.logits_to_keep > 1:
+            output_logits = output_logits[:, -padding_size - self.rbln_config.logits_to_keep : -padding_size, :]
         else:
-            output_logits = output_logits[:,
-        # index copy for masked output_logits
-        if attention_mask is not None:
-            new_output_logits = torch.full(
-                (1, attention_mask.shape[-1], output_logits.shape[-1]),
-                fill_value=1e-10,
-                dtype=output_logits.dtype,
-            )
-            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-            new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+            output_logits = output_logits[:, :-padding_size, :]

-
+        all_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            all_hidden_states = [
+                output_hidden_state[:, :-padding_size, :] for output_hidden_state in output_hidden_states
+            ]
+            all_hidden_states = tuple(all_hidden_states)

         # Update decoder attention mask with processed KV-cache length from prefill phase
         if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.
-
+            if self.rbln_config.use_position_ids:
+                self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+            else:
+                self.dec_attn_mask[batch_idx].fill_(0)
+                self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return RBLNDecoderOnlyOutput(
+        return RBLNDecoderOnlyOutput(
+            logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+        )
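After the chunk loop, the aggregated logits still cover the padded length, so either the trailing padding is sliced off or only the last logits_to_keep positions are retained. A worked slicing example with hypothetical sizes (the padding guard in the final branch is added here only to keep the standalone sketch safe when padding_size is zero):

```python
# Trimming the padded prefill logits (illustrative sizes, not library code).
import torch

prefill_chunk_size, query_length, vocab_size = 4, 10, 16
padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size   # -> 2
output_logits = torch.randn(1, query_length + padding_size, vocab_size)   # padded to 12 positions

logits_to_keep = 3
if logits_to_keep == 1:
    trimmed = output_logits   # buffer already holds a single position in this mode
elif logits_to_keep > 1:
    # Keep the last `logits_to_keep` real positions, skipping the trailing padding.
    trimmed = output_logits[:, -padding_size - logits_to_keep : -padding_size, :]
else:
    trimmed = output_logits[:, :-padding_size, :] if padding_size > 0 else output_logits

assert trimmed.shape[1] == logits_to_keep
```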
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch
+from transformers import GenerationConfig
 from transformers.generation.utils import GenerationMixin
+from transformers.modeling_outputs import ModelOutput


 if TYPE_CHECKING:
@@ -91,20 +93,26 @@ class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
         self,
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
-
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
-    ):
+    ) -> Union[ModelOutput, torch.LongTensor]:
         """
         The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.

         Args:
-            input_ids: The input ids to the model.
-            attention_mask: The attention mask to the model.
-
-
+            input_ids (torch.LongTensor): The input ids to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+        Returns:
+            A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.
         """
-        if
-        kwargs["
+        if generation_config is not None:
+            kwargs["generation_config"] = generation_config
         if attention_mask is not None:
             kwargs["attention_mask"] = attention_mask

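A hedged usage sketch of the updated generate() signature, passing an explicit GenerationConfig next to input_ids and attention_mask. The model class and loading arguments follow optimum-rbln conventions, and the checkpoint name is only an example; consult the package documentation for the exact API of your version.

```python
# Illustrative call of generate() with an explicit GenerationConfig (assumed model/checkpoint).
from transformers import AutoTokenizer, GenerationConfig
from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-3.2-1B-Instruct"          # example checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNLlamaForCausalLM.from_pretrained(model_id, export=True)  # compile for RBLN NPU

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generation_config = GenerationConfig(max_new_tokens=32, do_sample=False)

output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    generation_config=generation_config,   # kwargs matching its attributes would override it
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```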
optimum/rbln/transformers/models/decoderonly/lora_architecture.py

@@ -142,7 +142,7 @@ class LoRALinear(nn.Module):
         padded_lora_a = []
         padded_lora_b = []

-        for
+        for _, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
             current_rank = lora_a.shape[0]
             if current_rank < max_rank:
                 # Pad with zeros
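The one-line change above iterates over per-adapter LoRA A/B weight pairs so that adapters with different ranks can be zero-padded to a common max_rank and stacked. An illustrative sketch of that padding idea (variable names mirror the diff, but this is not the library implementation):

```python
# Zero-padding LoRA adapters of different ranks to a common max_rank (illustrative).
import torch
import torch.nn.functional as F

lora_a_weights = [torch.randn(8, 64), torch.randn(4, 64)]    # each: [rank, in_features]
lora_b_weights = [torch.randn(128, 8), torch.randn(128, 4)]  # each: [out_features, rank]
max_rank = max(a.shape[0] for a in lora_a_weights)

padded_lora_a, padded_lora_b = [], []
for lora_a, lora_b in zip(lora_a_weights, lora_b_weights):
    current_rank = lora_a.shape[0]
    if current_rank < max_rank:
        pad = max_rank - current_rank
        lora_a = F.pad(lora_a, (0, 0, 0, pad))  # pad extra rank rows with zeros
        lora_b = F.pad(lora_b, (0, pad))        # pad extra rank columns with zeros
    padded_lora_a.append(lora_a)
    padded_lora_b.append(lora_b)

stacked_a = torch.stack(padded_lora_a)  # [num_adapters, max_rank, in_features]
stacked_b = torch.stack(padded_lora_b)  # [num_adapters, out_features, max_rank]
```

Because the padded rows and columns are zero, the product B @ A of each padded pair is unchanged, so stacking does not alter any adapter's effective update.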