optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +12 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +16 -6
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +12 -8
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +242 -109
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +1 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +7 -1
- optimum/rbln/utils/runtime_utils.py +32 -0
- optimum/rbln/utils/submodule.py +3 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import deque
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import rebel
 import torch
@@ -24,6 +24,10 @@ from ...modeling_outputs import RBLNDecoderOnlyOutput
 from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


+if TYPE_CHECKING:
+    from transformers.configuration_utils import PreTrainedConfig
+
+
 class RBLNPageTableManager:
     EMPTY_BLOCK = -1
     NO_BLOCKS_ERROR = (
@@ -46,6 +50,12 @@ class RBLNPageTableManager:
         """
         If the block is empty (empty_block), allocates a block from the free_block_pool.
         """
+        if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
+            raise IndexError(
+                f"Invalid index(batch_idx={batch_idx}, block_idx={block_idx}): \n \
+                BlockTable Shape(batch_axis, block_axis): {self.block_tables.shape}, BlockSize: {self.rbln_config.kvcache_block_size}"
+            )
+
         if self.block_tables[batch_idx][block_idx] == self.EMPTY_BLOCK:
             if self.free_block_pool:
                 block = self.free_block_pool.popleft()
@@ -96,8 +106,6 @@ class RBLNPageTableManager:
         s, e = cache_position[0][0].item(), cache_position[0][-1].item()
         for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
             block_idx = position // self.rbln_config.kvcache_block_size
-            if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
             self.update_block(batch_idx, block_idx)

         return self.replace_empty_block(self.block_tables[batch_idx])
@@ -169,20 +177,23 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         dec_attn_mask: torch.Tensor,
         page_table_manager: RBLNPageTableManager,
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-
+        config: "PreTrainedConfig" = None,
+        logits_last_dim: Optional[int] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(runtime, **kwargs)
         self.phase = phase
         self.batch_size = batch_size
         self.rbln_config = rbln_config
+        self.config = config
+        self.logits_last_dim = logits_last_dim

         # shared resources between prefill and decode phase
         self.dec_attn_mask = dec_attn_mask
         self.page_table_manager = page_table_manager
+        self.out_buffers = None

         if self.phase == "prefill":
-            self.out_buffers = out_buffers
             self.causal_mask = 1 - torch.triu(
                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
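Not part of the published diff: a minimal sketch of the `causal_mask` built in the prefill branch above, using a toy `prefill_chunk_size` to show the lower-triangular pattern it produces.

```python
import torch

prefill_chunk_size = 4  # toy value for illustration
causal_mask = 1 - torch.triu(
    torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
)
print(causal_mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```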
@@ -276,28 +287,48 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
             raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")

-
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
             raise RuntimeError(
-                f"
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        if self.rbln_config.use_local_attention:
+            local_block_tables = (
+                local_block_tables
+                if local_block_tables is not None
+                else torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
             )

         if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(
+            for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
                     raise ValueError(
                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
                     )

-                if
-                self.dec_attn_mask[b_idx]
-
+                if self.rbln_config.use_position_ids:
+                    self.dec_attn_mask[b_idx, decoding_step] = 1
+
+                    if self.batch_size < block_tables.shape[0]:
+                        block_tables = block_tables[: self.batch_size]
+
+                    if self.dec_attn_mask is not None and self.batch_size < self.dec_attn_mask.shape[0]:
+                        self.dec_attn_mask = self.dec_attn_mask[: self.batch_size]
                 else:
-
+                    if is_external_block_tables:
+                        self.dec_attn_mask[b_idx].fill_(0)
+                        self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
+                    else:
+                        self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

             attention_mask = self.dec_attn_mask

-
+        outputs = super().forward(
             inputs,
             cache_position,
             block_tables,
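For reference (not from the package): the fallback `local_block_tables` constructed above when none are supplied is simply one block index per batch entry.

```python
import torch

batch_size = 3  # illustrative; the runtime uses inputs.shape[0]
local_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, -1)
print(local_block_tables)
# tensor([[0],
#         [1],
#         [2]], dtype=torch.int16)
```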
@@ -306,15 +337,20 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             attention_mask if self.rbln_config.use_attention_mask else None,
             position_ids if self.rbln_config.use_position_ids else None,
             lora_int_ids if self.rbln_config.use_lora else None,
+            out=self.out_buffers,
         )

-
+        if self.rbln_config.output_hidden_states:
+            return RBLNDecoderOnlyOutput(logits=outputs[0], hidden_states=tuple(outputs[1:]))
+        else:
+            return RBLNDecoderOnlyOutput(logits=outputs, hidden_states=None)

     def _prepare_prefill_inputs(
         self,
         inputs: torch.Tensor,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
     ):
@@ -324,9 +360,27 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         if attention_mask is not None:
-
-
-
+            if attention_mask.dim() != 1:
+                raise ValueError("attention_mask must be a 1D tensor.")
+
+            mask_bool = attention_mask.to(dtype=torch.bool)
+            if (~mask_bool).any():
+                indice_one = torch.nonzero(mask_bool, as_tuple=False)
+                if indice_one.numel() == 0:
+                    raise ValueError("attention_mask with padding must include at least one real token.")
+                first_one_idx, last_one_idx = int(indice_one[0].item()), int(indice_one[-1].item())
+                if last_one_idx - first_one_idx + 1 != mask_bool.sum():
+                    raise ValueError(
+                        "attention_mask must group all 1s together (e.g. 000111 or 1111000). "
+                        "Zeros between real tokens like 101010 are not supported."
+                    )
+
+                if self.rbln_config.can_generate and not mask_bool[first_one_idx:].all():
+                    raise ValueError("attention_mask must be left padded for generation.")
+
+            inputs = inputs[:, mask_bool]
+            position_embed = None if position_embed is None else position_embed[:, :, :, mask_bool, :]
+            token_type_ids = None if token_type_ids is None else token_type_ids[:, mask_bool]

         query_length = inputs.shape[1]
         if query_length > self.rbln_config.max_seq_len:
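A standalone sketch (not in the diff) of the contiguity rule the new prefill validation enforces: the 1s in a 1-D attention mask must form a single run, which the code checks by comparing the span between the first and last 1 against the total count of 1s.

```python
import torch

def mask_is_contiguous(attention_mask: torch.Tensor) -> bool:
    # Mirrors the check added above: all 1s must form one contiguous run.
    mask_bool = attention_mask.to(dtype=torch.bool)
    ones = torch.nonzero(mask_bool, as_tuple=False)
    if ones.numel() == 0:
        return False
    first, last = int(ones[0].item()), int(ones[-1].item())
    return last - first + 1 == int(mask_bool.sum())

print(mask_is_contiguous(torch.tensor([0, 0, 0, 1, 1, 1])))  # True  (left padded)
print(mask_is_contiguous(torch.tensor([1, 1, 1, 1, 0, 0])))  # True  (right padded)
print(mask_is_contiguous(torch.tensor([1, 0, 1, 0, 1, 0])))  # False (holes are rejected)
```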
@@ -335,17 +389,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             )

         # Initialize attention mask for chunked processing
-
-
-
-
-
-
-
-
-
-
-
+        if self.rbln_config.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                chunked_attention_mask = torch.zeros(
+                    1, self.rbln_config.max_seq_len, dtype=self.rbln_config.torch_dtype
+                )
+            else:
+                chunked_attention_mask = torch.zeros(
+                    1,
+                    1,
+                    self.rbln_config.prefill_chunk_size,
+                    self.rbln_config.max_seq_len,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+        else:
+            chunked_attention_mask = None

         cache_position = (
             torch.arange(query_length, dtype=torch.int32).unsqueeze(0) if cache_position is None else cache_position
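For orientation only (not from the package): the two `chunked_attention_mask` layouts initialized above differ in rank; with position IDs enabled the runtime keeps a flat per-position mask, otherwise a broadcastable 4-D mask covering one prefill chunk.

```python
import torch

max_seq_len, prefill_chunk_size = 8, 4  # toy sizes

mask_2d = torch.zeros(1, max_seq_len)                         # use_position_ids=True
mask_4d = torch.zeros(1, 1, prefill_chunk_size, max_seq_len)  # use_position_ids=False
print(mask_2d.shape, mask_4d.shape)  # torch.Size([1, 8]) torch.Size([1, 1, 4, 8])
```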
@@ -363,7 +421,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position = F.pad(cache_position, (0, padding_size))

         # Overwrite position_ids and padded_cache_lengths
-
+        if self.rbln_config.use_position_ids and position_ids is None:
+            position_ids = cache_position.clone()
+        else:
+            position_ids = position_ids
+
         padded_cache_lengths = 0

         return (
@@ -377,6 +439,68 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             token_type_ids,
         )

+    def _prepare_prefill_outputs(
+        self,
+        query_length: int,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        # Prepare out buffers
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        padded_input_length = query_length + padding_size
+        padded_mask_length = (
+            attention_mask.shape[-1] + padding_size if attention_mask is not None else padded_input_length
+        )
+        out_buffers = [[] for _ in range(padded_input_length // self.rbln_config.prefill_chunk_size)]
+
+        valid_start_index = (
+            int(torch.nonzero(attention_mask, as_tuple=False)[0][0].item()) if attention_mask is not None else 0
+        )
+
+        if self.logits_last_dim is None:
+            logits_last_dim = self.config.vocab_size if self.rbln_config.can_generate else self.config.hidden_size
+        else:
+            logits_last_dim = self.logits_last_dim
+
+        # Prepare logits buffer
+        logits_size = (
+            1,
+            1 if self.rbln_config.logits_to_keep == 1 else padded_mask_length,
+            logits_last_dim,
+        )
+        output_logits = torch.full(logits_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+
+        if self.rbln_config.logits_to_keep == 1:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                out_buffers[i].append(output_logits)
+        else:
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].append(output_logits[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size])
+
+        # Prepare output hidden states
+        output_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            hidden_states_size = (
+                1,
+                padded_mask_length,
+                self.config.hidden_size,
+            )
+            output_hidden_states = [
+                torch.full(hidden_states_size, fill_value=1e-10, dtype=self.rbln_config.torch_dtype)
+                for _ in range(self.config.num_hidden_layers + 1)
+            ]
+
+            for i in range(padded_input_length // self.rbln_config.prefill_chunk_size):
+                s_idx = i * self.rbln_config.prefill_chunk_size + valid_start_index
+                out_buffers[i].extend(
+                    [
+                        hidden_states_buffer[:, s_idx : s_idx + self.rbln_config.prefill_chunk_size]
+                        for hidden_states_buffer in output_hidden_states
+                    ]
+                )
+
+        return out_buffers, output_logits, output_hidden_states
+
     def prefill_forward(
         self,
         inputs: torch.Tensor,
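A toy illustration (outside the diff) of the buffer strategy `_prepare_prefill_outputs` sets up: each per-chunk entry is a view into one preallocated logits tensor, so writing a chunk's output through `out=` fills the corresponding slice of the full tensor in place.

```python
import torch

prefill_chunk_size, padded_input_length, last_dim = 4, 8, 6  # made-up sizes

output_logits = torch.full((1, padded_input_length, last_dim), 1e-10)
out_buffers = [
    output_logits[:, i * prefill_chunk_size : (i + 1) * prefill_chunk_size]
    for i in range(padded_input_length // prefill_chunk_size)
]

# Writing into the second chunk's view updates output_logits directly.
out_buffers[1].copy_(torch.ones(1, prefill_chunk_size, last_dim))
print(output_logits[0, 4:, 0])  # tensor([1., 1., 1., 1.])
```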
@@ -385,6 +509,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
         is_external_block_tables: Optional[bool] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
@@ -417,9 +542,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             query_length,
             token_type_ids,
         ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
+            inputs, cache_position, attention_mask, position_ids, position_embed, token_type_ids=token_type_ids
        )

+        out_buffers, output_logits, output_hidden_states = self._prepare_prefill_outputs(query_length, attention_mask)
+
         # Assumed that prefix caching was performed externally if cache_position doesn't start from 0.
         prefix_cached_len = cache_position[0][0].item()
         if prefix_cached_len > 0:
@@ -428,11 +555,13 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                     "Prefix Caching is not supported yet for non-multiple of prefill_chunk_size."
                 )
             if self.rbln_config.use_attention_mask:
-
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[:, :prefix_cached_len] = 1
+                else:
+                    chunked_attention_mask[:, :, :, :prefix_cached_len] = 1

         # Process input in chunks of size `prefill_chunk_size`
-
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+        for i, step in enumerate(range(0, query_length, self.rbln_config.prefill_chunk_size)):
             s, e = step, step + self.rbln_config.prefill_chunk_size
             # Extract the current chunk of inputs, cache positions, position ids, and position embeddings
             input_chunk = inputs[:, s:e]
@@ -441,17 +570,29 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_embed_chunk = position_embed[:, :, :, s:e, :] if position_embed is not None else None

             # Update attention mask to ensure proper causal behavior
-            if self.rbln_config.use_attention_mask
-                if
-
-
-
-
-
-
-
-
-
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    if step > 0: # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = min(e, query_length) + prefix_cached_len
+                    if current_chunk_end > current_chunk_start:
+                        chunked_attention_mask[:, current_chunk_start:current_chunk_end] = 1
+
+                else:
+                    if step > 0: # update previous chunk
+                        # Update attention mask for the previous chunk (from s - prefill_chunk_size to s)
+                        prev_chunk_start = s - self.rbln_config.prefill_chunk_size + prefix_cached_len
+                        prev_chunk_end = s + prefix_cached_len
+                        chunked_attention_mask[:, :, :, prev_chunk_start:prev_chunk_end] = 1
+
+                    current_chunk_start = s + prefix_cached_len
+                    current_chunk_end = e + prefix_cached_len
+                    chunked_attention_mask[:, :, :, current_chunk_start:current_chunk_end] = self.causal_mask

             # Calculate query position if needed
             if self.rbln_config.use_local_attention or self.rbln_config.logits_to_keep > 0:
@@ -464,7 +605,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 query_position = None

             # Forward pass for the current chunk
-
+            _ = super().forward(
                 input_chunk,
                 cache_pos_chunk,
                 block_tables,
@@ -474,31 +615,33 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
                 position_ids_chunk,
                 lora_int_ids if self.rbln_config.use_lora else None,
-                out=
+                out=out_buffers[i],
             )
-            output_logits.append(output_logit)

         # Aggregate output_logits
-
-        if self.rbln_config.logits_to_keep
-            output_logits = output_logits
+        padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
+        if self.rbln_config.logits_to_keep == 1:
+            output_logits = output_logits
+        elif self.rbln_config.logits_to_keep > 1:
+            output_logits = output_logits[:, -padding_size - self.rbln_config.logits_to_keep : -padding_size, :]
         else:
-            output_logits = output_logits[:,
-            # index copy for masked output_logits
-            if attention_mask is not None:
-                new_output_logits = torch.full(
-                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
-                    fill_value=1e-10,
-                    dtype=output_logits.dtype,
-                )
-                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
+            output_logits = output_logits[:, :-padding_size, :]

-
+        all_hidden_states = None
+        if self.rbln_config.output_hidden_states:
+            all_hidden_states = [
+                output_hidden_state[:, :-padding_size, :] for output_hidden_state in output_hidden_states
+            ]
+            all_hidden_states = tuple(all_hidden_states)

         # Update decoder attention mask with processed KV-cache length from prefill phase
         if self.rbln_config.can_generate and not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.
-
+            if self.rbln_config.use_position_ids:
+                self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
+            else:
+                self.dec_attn_mask[batch_idx].fill_(0)
+                self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

-        return RBLNDecoderOnlyOutput(
+        return RBLNDecoderOnlyOutput(
+            logits=output_logits, padded_cache_lengths=padded_cache_lengths, hidden_states=all_hidden_states
+        )
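A quick arithmetic check (not part of the diff) of the padding that the aggregation above trims back off when all logits are kept: the query is rounded up to a multiple of `prefill_chunk_size`, and the padded tail is sliced away at the end.

```python
prefill_chunk_size = 128
query_length = 300

padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
print(padding_size)                 # 84
print(query_length + padding_size)  # 384, i.e. 3 full chunks; output_logits[:, :-84, :] restores length 300
```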
optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import torch
+from transformers import GenerationConfig
 from transformers.generation.utils import GenerationMixin
+from transformers.modeling_outputs import ModelOutput


 if TYPE_CHECKING:
@@ -91,20 +93,26 @@ class RBLNDecoderOnlyGenerationMixin(GenerationMixin):
         self,
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
-
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
-    ):
+    ) -> Union[ModelOutput, torch.LongTensor]:
         """
         The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.

         Args:
-            input_ids: The input ids to the model.
-            attention_mask: The attention mask to the model.
-
-
+            input_ids (torch.LongTensor): The input ids to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+        Returns:
+            A ModelOutput (if return_dict_in_generate=True or when config.return_dict_in_generate=True) or a torch.LongTensor.
         """
-        if
-            kwargs["
+        if generation_config is not None:
+            kwargs["generation_config"] = generation_config
         if attention_mask is not None:
             kwargs["attention_mask"] = attention_mask

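A hypothetical usage sketch for the updated `generate` signature; the model class is real, but the checkpoint path, tokenizer name, and the assumption of an already-compiled model directory are placeholders.

```python
from transformers import AutoTokenizer, GenerationConfig

from optimum.rbln import RBLNLlamaForCausalLM  # any RBLN decoder-only class using this mixin

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder tokenizer
model = RBLNLlamaForCausalLM.from_pretrained("./llama-rbln")           # placeholder compiled-model directory

inputs = tokenizer("Hello, RBLN!", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    generation_config=GenerationConfig(max_new_tokens=32, do_sample=False),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```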
optimum/rbln/transformers/models/decoderonly/lora_architecture.py

@@ -142,7 +142,7 @@ class LoRALinear(nn.Module):
         padded_lora_a = []
         padded_lora_b = []

-        for
+        for _, (lora_a, lora_b) in enumerate(zip(lora_a_weights, lora_b_weights)):
             current_rank = lora_a.shape[0]
             if current_rank < max_rank:
                 # Pad with zeros