optimum-rbln 0.2.1a3__py3-none-any.whl → 0.2.1a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/modeling_base.py +10 -9
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +3 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +200 -154
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -7
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +59 -37
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +1 -1
- {optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/METADATA +1 -1
- {optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/RECORD +11 -11
- {optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__version__.py
CHANGED
optimum/rbln/modeling_base.py
CHANGED
```diff
@@ -442,8 +442,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

-
-
+        # Normalize paths to handle relative paths and symlinks
+        real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
+        save_directory_path = Path(save_directory).resolve()

         if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
             raise FileNotFoundError(
@@ -452,13 +453,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                 f"Please ensure the model directory exists and you have the necessary permissions to access it."
             )

-        if save_directory_path
+        if save_directory_path == real_save_dir:
            raise FileExistsError(
                f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
            )

-        # Create a temporary directory
-        tmp_dir =
+        # Create a temporary directory with normalized path
+        tmp_dir = str(save_directory_path) + ".tmp"
        try:
            # Remove temporary directory if it exists from a previous failed attempt
            if os.path.exists(tmp_dir):
@@ -473,9 +474,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
                self.generation_config.save_pretrained(tmp_dir)

            # If everything succeeded, atomically replace the target directory
-            if os.path.exists(
-            shutil.rmtree(
-            os.rename(tmp_dir,
+            if os.path.exists(save_directory_path):
+                shutil.rmtree(save_directory_path)
+            os.rename(tmp_dir, save_directory_path)

        except Exception as e:
            # Clean up the temporary directory if anything fails
@@ -484,7 +485,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
            raise e  # Re-raise the exception after cleanup

        if push_to_hub:
-            return super().push_to_hub(
+            return super().push_to_hub(str(save_directory_path), **kwargs)

    @staticmethod
    def _raise_missing_compiled_file_error(missing_files: List[str]):
```
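The change above moves `save_pretrained` to a normalize-then-stage-then-swap flow: both paths are resolved, everything is written into a `.tmp` sibling directory, and only a successful write replaces the target. The sketch below is an illustrative reduction of that pattern, not the library code; the helper name `atomic_save_dir` and the `copytree` payload writer are hypothetical stand-ins.

```python
import os
import shutil
from pathlib import Path


def atomic_save_dir(model_save_dir: str, save_directory: str) -> None:
    """Minimal sketch: normalize paths, stage into a temporary directory,
    then atomically swap it into place (assumed layout, not optimum-rbln API)."""
    real_save_dir = Path(model_save_dir).resolve()
    save_directory_path = Path(save_directory).resolve()

    if save_directory_path == real_save_dir:
        raise FileExistsError(f"Cannot save model to '{save_directory}'.")

    tmp_dir = str(save_directory_path) + ".tmp"
    try:
        # Remove leftovers from a previously failed attempt
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        shutil.copytree(real_save_dir, tmp_dir)  # stand-in for writing the compiled artifacts

        # Only replace the target once the staged copy is complete
        if os.path.exists(save_directory_path):
            shutil.rmtree(save_directory_path)
        os.rename(tmp_dir, save_directory_path)
    except Exception:
        # Clean up the staging directory, then re-raise
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        raise
```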
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
CHANGED
```diff
@@ -427,12 +427,14 @@ class DecoderOnlyModel(nn.Module):
            cos, sin = None, None

        # (batch, seq_len) -> (batch,)
-        seq_positions = cache_position[:, 0]
        if self.attn_impl == "flash_attn":
+            seq_positions = cache_position[:, 0]
            max_seq_len = past_key_values[0][0].shape[-2]
            seq_positions = self.convert_sequence_positions_for_flash_attn(
                seq_positions=seq_positions, max_seq_len=max_seq_len
            )
+        else:
+            seq_positions = cache_position[:, :1]

        present_key_values = past_key_values
        for layer in self.layers:
```
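For context, the two indexing forms differ only in rank: the flash-attention path keeps the flattened per-batch positions, while the default path now keeps a rank-2 slice. A brief illustration of the shapes involved (the tensor values here are made up):

```python
import torch

# cache_position has shape (batch, seq_len)
cache_position = torch.arange(8, dtype=torch.int32).reshape(2, 4)

flash_attn_positions = cache_position[:, 0]   # shape (2,)   -- used for flash_attn
default_positions = cache_position[:, :1]     # shape (2, 1) -- used otherwise

print(flash_attn_positions.shape, default_positions.shape)
```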
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
CHANGED
```diff
@@ -38,34 +38,188 @@ from .decoderonly_architecture import (
 logger = get_logger()

 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


 class RBLNRuntimeModel(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name", "embed_tokens"]

+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        phase: str,
+        batch_size: int,
+        dec_attn_mask: torch.Tensor,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.phase = phase
+        self.batch_size = batch_size
+
+        # shared tensor between prefill and decode phase
+        self.dec_attn_mask = dec_attn_mask
+
+        if self.phase == "prefill":
+            vocab_size = kwargs.pop("vocab_size")
+            self.max_seq_len = kwargs.pop("max_seq_len")
+            self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
+            self.output_size = [1, 1, vocab_size]
+            self.causal_mask = 1 - torch.triu(
+                torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+            )
+
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        inputs_embeds: torch.Tensor,
-
-
-
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: Optional[int] = None,
     ):
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
         if inputs_embeds is None:
-
+            inputs = input_ids
             if self.embed_tokens is not None:
-
+                inputs = self.embed_tokens(inputs)
         else:
-
+            inputs = inputs_embeds

-
-
-
+        if self.phase == "decode":
+            return self.decode_forward(
+                inputs,
+                cache_position,
+                attention_mask=attention_mask,
+            )
+        else:
+            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+
+    def decode_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size = inputs.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        if attention_mask is None:
+            for b_idx in range(batch_size):
+                decoding_step = cache_position[b_idx].item()
+                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                    raise ValueError(
+                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                    )
+                self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+        logits = super().forward(
+            inputs,
+            self.dec_attn_mask if attention_mask is None else attention_mask,
             cache_position,
-            **kwargs,
         )

+        return logits
+
+    def prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        cache_position: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        batch_idx: int = None,
+    ) -> torch.FloatTensor:
+        """
+        Performs chunked prefill for efficient KV-cache updates and memory optimization.
+        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+        """
+
+        if batch_idx is None or batch_idx >= self.batch_size:
+            raise RuntimeError(
+                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
+            )
+
+        # Handle continuous batching in a compiled graph by extracting valid inputs
+        # If an attention mask is provided, select only the valid (non-masked) inputs
+        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+        query_length = inputs.shape[1]
+        if query_length > self.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+            )
+
+        # Initialize attention mask for chunked processing
+        chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+
+        # Buffer for storing output logits
+        out_buffers = [
+            torch.empty(
+                size=self.output_size,
+                dtype=torch.float32,
+                device="cpu",
+            )
+        ]
+
+        # Process input in chunks of size `prefill_chunk_size`
+        for step in range(0, query_length, self.prefill_chunk_size):
+            # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+            if (step + self.prefill_chunk_size) > query_length:
+                padding_size = step + self.prefill_chunk_size - query_length
+                # inputs_embeds
+                if inputs.dim() == 3:
+                    inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+                # inputs_ids
+                else:
+                    inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+
+                cache_position = torch.cat(
+                    [
+                        cache_position,
+                        torch.arange(
+                            query_length,
+                            step + self.prefill_chunk_size,
+                            dtype=torch.int32,
+                        ).unsqueeze(0),
+                    ],
+                    dim=-1,
+                )
+
+            # Extract the current chunk of inputs and cache positions
+            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+
+            # Update attention mask to ensure proper causal behavior
+            if step >= self.prefill_chunk_size:
+                chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+            chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+            # Define batch position and query position
+            batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+            query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+
+            # Forward pass for the current chunk
+            logits = super().forward(
+                input_chunk,
+                chunked_attention_mask,
+                cache_pos_chunk,
+                batch_position,
+                query_position,
+                out=out_buffers,
+            )
+
+        # Update decoder attention mask with processed KV-cache length from prefill phase
+        self.dec_attn_mask[batch_idx].fill_(0)
+        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+        return logits
+

 @dataclass
 class RBLNDecoderOnlyOutput(ModelOutput):
```
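The `prefill_forward` added above walks the prompt in fixed-size chunks and maintains a (1, 1, chunk, max_seq_len) attention mask as it goes: already-processed positions attend fully, the current chunk gets a causal block. The standalone sketch below shows only that mask bookkeeping and chunk slicing, with made-up sizes and no RBLN runtime involved.

```python
import torch

prefill_chunk_size = 4
max_seq_len = 16
query_length = 10

# lower-triangular causal block for the current chunk
causal_mask = 1 - torch.triu(torch.ones(1, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1)
chunked_attention_mask = torch.zeros(1, 1, prefill_chunk_size, max_seq_len)
inputs = torch.arange(query_length).unsqueeze(0)  # stand-in for token ids

for step in range(0, query_length, prefill_chunk_size):
    # pad the final partial chunk so every call sees a full chunk
    if step + prefill_chunk_size > query_length:
        inputs = torch.nn.functional.pad(inputs, (0, step + prefill_chunk_size - query_length))

    input_chunk = inputs[:, step : step + prefill_chunk_size]

    # positions from earlier chunks become fully visible; the current chunk stays causal
    if step >= prefill_chunk_size:
        chunked_attention_mask[:, :, :, step - prefill_chunk_size : step] = 1
    chunked_attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask

    # a compiled runtime call would consume (input_chunk, chunked_attention_mask) here
```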
```diff
@@ -103,13 +257,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
         self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
-        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-
         main_input_name = self.main_input_name
         if self.rbln_config.model_cfg["use_inputs_embeds"]:
             main_input_name = "inputs_embeds"
@@ -124,11 +271,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         else:
             self.embed_tokens = None

+        dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
         self.prefill_decoder = RBLNRuntimeModel(
-            runtime=self.model[0],
+            runtime=self.model[0],
+            main_input_name=main_input_name,
+            embed_tokens=self.embed_tokens,
+            phase="prefill",
+            batch_size=self.batch_size,
+            dec_attn_mask=dec_attn_mask,
+            vocab_size=self.config.vocab_size,
+            max_seq_len=self.max_seq_len,
+            prefill_chunk_size=self.prefill_chunk_size,
         )
         self.decoder = RBLNRuntimeModel(
-            runtime=self.model[1],
+            runtime=self.model[1],
+            main_input_name=main_input_name,
+            embed_tokens=self.embed_tokens,
+            phase="decode",
+            batch_size=self.batch_size,
+            dec_attn_mask=dec_attn_mask,
         )

     @classmethod
@@ -155,7 +316,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
-        config: Optional[PretrainedConfig] = None,
+        config: Optional["PretrainedConfig"] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -496,32 +657,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-
+        """
+        Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+        For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+        A for-loop ensures synchronization with the HuggingFace generate API.
+        The decoder stage operates as usual, processing inputs in batch mode.
+        """
+        # Prefll
         if cache_position is None:
             logits = []
-
-            batch_size =
+            inputs = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = inputs.shape[0]

             for b_idx in range(batch_size):
-                # Transform inputs as vllm format
-                if attention_mask is not None:
-                    input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
-                else:
-                    input_tensor = input_tensors[b_idx : b_idx + 1]
-
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
-
-
-
+                logit = self.prefill_decoder(
+                    input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                    inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                    attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                 )
                 logits.append(logit)
+
             logits = torch.cat(logits, dim=0)
-        #
+        # Decoder
         else:
-            logits = self.
+            logits = self.decoder(
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
```
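The reworked `forward` above dispatches on `cache_position`: `None` means prefill, where each sample is run separately so its KV cache slot can be filled via `batch_idx`; anything else means batched decode. The sketch below is a hedged reduction of that dispatch, with `dispatch_forward` and the two callables standing in for the compiled prefill/decode runtimes (all names here are hypothetical, not the optimum-rbln API).

```python
from typing import Callable, Optional

import torch


def dispatch_forward(
    inputs: torch.Tensor,
    generate_idx: torch.Tensor,
    cache_position: Optional[torch.Tensor],
    prefill_decoder: Callable[..., torch.Tensor],
    decoder: Callable[..., torch.Tensor],
) -> torch.Tensor:
    """Illustrative only: prefill runs once per sample so each sequence can be
    written to its own KV-cache slot; decode runs the whole batch at once."""
    if cache_position is None:  # prefill phase
        logits = []
        for b_idx in range(inputs.shape[0]):
            # positions 0..prompt_length-1 for this sample
            pos = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
            logits.append(prefill_decoder(inputs[b_idx : b_idx + 1], pos, b_idx))
        return torch.cat(logits, dim=0)
    # decode phase: one token per sample, processed as a batch
    return decoder(inputs, cache_position)
```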
```diff
@@ -531,119 +693,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             logits=logits,
             generate_idx=generate_idx,
         )
-
-    def _forward_prefill(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: int = None,
-    ) -> torch.FloatTensor:
-        if batch_idx is None or batch_idx >= self.batch_size:
-            raise RuntimeError(
-                f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
-            )
-
-        out_buffers = [
-            torch.empty(
-                size=[
-                    1,
-                    1,
-                    self.config.vocab_size,
-                ],
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
-        query_length = input_tensors.shape[1]
-        if query_length > self.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
-            )
-
-        _attention_mask = self.prefill_attention_mask.clone()
-
-        for step in range(0, query_length, self.prefill_chunk_size):
-            # pad input_tensors & cache_position for prefill_chunk
-            if (step + self.prefill_chunk_size) > query_length:
-                pad_to_chunk = step + self.prefill_chunk_size - query_length
-                if inputs_embeds is not None:
-                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
-                else:
-                    input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
-
-                cache_position = torch.cat(
-                    [
-                        cache_position,
-                        torch.arange(
-                            query_length,
-                            step + self.prefill_chunk_size,
-                            dtype=torch.int32,
-                        ).unsqueeze(0),
-                    ],
-                    dim=-1,
-                )
-
-            # slice input_tensor & cache_position with prefill_chunk_size
-            _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
-            _cache_position = cache_position[:, step : step + self.prefill_chunk_size]
-
-            # update attention_mask
-            if step >= self.prefill_chunk_size:
-                _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-            _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
-            query_position = (query_length - 1) % self.prefill_chunk_size
-
-            logits = self.prefill_decoder(
-                input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
-                inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
-                attention_mask=_attention_mask.contiguous(),
-                cache_position=_cache_position.contiguous(),
-                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
-                query_position=torch.tensor(query_position, dtype=torch.int16),
-                out=out_buffers,
-            )
-
-        # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
-        self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
-        self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-        return logits
-
-    def _forward_decoder(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-    ) -> torch.FloatTensor:
-        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
-        if input_tensors is None:
-            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-        batch_size = input_tensors.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        for b_idx in range(batch_size):
-            decoding_step = cache_position[b_idx].item()
-            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                raise ValueError(
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-            self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-        logits = self.decoder(
-            input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
-            inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
-            attention_mask=self.dec_attn_mask.contiguous(),
-            cache_position=cache_position.contiguous(),
-        )
-
-        return logits
```
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
CHANGED
```diff
@@ -25,7 +25,6 @@ from transformers import (
     PreTrainedModel,
 )
 from transformers.modeling_outputs import BaseModelOutputWithPooling
-from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast

 from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -337,7 +336,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         generate_idx: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         **kwargs,
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         vision_feature_layer = (
             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
         )
@@ -378,7 +377,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
             inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
             for batch_idx in range(batch_size):
                 generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
-                logit = self.language_model.
+                logit = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[batch_idx],
                     batch_idx=batch_idx,
                     cache_position=torch.arange(
@@ -390,15 +389,13 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):

                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
-            outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
         else:
-
+            logits = self.language_model.decoder(
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
-                generate_idx=generate_idx,
             )

-        return
+        return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)

     # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
```
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
CHANGED
```diff
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import AutoModelForSeq2SeqLM,
+from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

 from ....modeling import RBLNModel
@@ -31,12 +31,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
 logger = get_logger(__name__)

 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig


 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -50,9 +45,50 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]

-    def
-
-
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        batch_size: int,
+        dec_max_seq_len: int,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self.batch_size = batch_size
+        self.dec_max_seq_len = dec_max_seq_len
+
+    def forward(
+        self,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        batch_size = decoder_input_ids.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+        for b_idx in range(self.batch_size):
+            decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_max_seq_len):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
+            decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+        lm_logits = super().forward(
+            decoder_input_ids,
+            decoder_attention_mask,
+            attention_mask,
+            cache_position,
+        )
+
+        return Seq2SeqLMOutput(logits=lm_logits)


 class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
```
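The `RBLNRuntimeDecoder` above rebuilds the decoder attention mask from `cache_position` on every step, opening one more position per sample before calling the compiled runtime. A minimal sketch of that per-sample masking, with assumed shapes and values:

```python
import torch

batch_size, dec_max_seq_len = 2, 8
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)
cache_position = torch.tensor([[3], [5]], dtype=torch.int32)  # current step per sample

for b_idx in range(batch_size):
    decoding_step = cache_position[b_idx].item()
    if not (0 <= decoding_step < dec_max_seq_len):
        raise ValueError(f"Decoding step {decoding_step} out of bounds.")
    # every position up to and including the current step becomes visible
    decoder_attention_mask[b_idx, : decoding_step + 1] = 1

print(decoder_attention_mask)
```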
```diff
@@ -72,8 +108,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     auto_model_class = AutoModelForSeq2SeqLM

     def __post_init__(self, **kwargs):
-
-
+        batch_size = self.rbln_config.model_cfg["batch_size"]
+        dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0],
+            main_input_name="input_ids",
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1], main_input_name="input_ids", batch_size=batch_size, dec_max_seq_len=dec_max_seq_len
+        )

     @classmethod
     @torch.inference_mode()
@@ -304,46 +347,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

     def forward(
         self,
-
+        decoder_input_ids: torch.LongTensor = None,
         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # common decoder
         cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
-        logits = self.
+        logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits

         return Seq2SeqLMOutput(
             logits=logits,
         )

-    def _forward_decoder(
-        self,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
-        decoder_attention_mask: Optional[torch.BoolTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        dec_attention_mask = decoder_attention_mask.clone()
-        for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
-            dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
-
-        decoder_output = self.decoder(
-            input_ids=decoder_input_ids,
-            attention_mask=dec_attention_mask,
-            encoder_attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        lm_logits = decoder_output.logits
-
-        return Seq2SeqLMOutput(logits=lm_logits)
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
         model_input_name: Optional[str] = None,
-        generation_config: Optional[GenerationConfig] = None,
+        generation_config: Optional["GenerationConfig"] = None,
     ) -> Dict[str, Any]:
         # 1. get encoder
         encoder = self.get_encoder()
@@ -373,6 +394,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
         )

         # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
         encoder_kwargs["output_hidden_states"] = False
         encoder_kwargs["output_attentions"] = False
```
optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
CHANGED
```diff
@@ -459,7 +459,7 @@ class Seq2SeqSelfAttention(nn.Module):
            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
-            cache_position
+            cache_position,
            torch.tensor(1.0, dtype=torch.float32),  # scale
        )

```
{optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: optimum-rbln
-Version: 0.2.1a3
+Version: 0.2.1a5
 Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
 Project-URL: Homepage, https://rebellions.ai
 Project-URL: Documentation, https://docs.rbln.ai
```
{optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/RECORD
CHANGED
```diff
@@ -1,7 +1,7 @@
 optimum/rbln/__init__.py,sha256=sLCjJu_MLZEKDOwHIlJP4u4GzGZx-1kqHTYGw5B4xDg,6096
-optimum/rbln/__version__.py,sha256=
+optimum/rbln/__version__.py,sha256=J4Eyn4HLzB0UpyosVo-P3LCDkB5knEOS6Nu24mnl5NA,413
 optimum/rbln/modeling.py,sha256=REImAAKO82CqSNABR-9E1jJEsWch9amSOwOOQhFEYLY,8283
-optimum/rbln/modeling_base.py,sha256=
+optimum/rbln/modeling_base.py,sha256=fQ0bI1Bb6GJquRXftmSSN9K-TXLhFltZJ6C-2w43xMg,21193
 optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
 optimum/rbln/diffusers/__init__.py,sha256=68FTAMpbbMflm8qiSqfM5J2_gFb3iU3fng6AL0TG47A,2913
 optimum/rbln/diffusers/modeling_diffusers.py,sha256=E1x-iOKEJCUB6ml0RgtFEVPPk6J6pqEF-JTEyOZzOyc,14928
@@ -53,8 +53,8 @@ optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=-nv-sgmHkyHQIoQvF8
 optimum/rbln/transformers/models/clip/__init__.py,sha256=ssJqlEt318ti2QaEakGh_tO3Ap1VSPCVF-ymUuvjAJs,698
 optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=E1QfVNq1sTCp7uvuha1ZPfXMwvMTkGV9L4oFdmy1w4g,5724
 optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
-optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=
-optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=
+optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=eT1fbKDL92yGBXtUKA_JibD4kiRPdf3tAFJHP5nlfH4,36646
+optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=2OO8MEgFgcl1VPrQXxqkvmRJJEuFdexwu8XqbHDbR6Y,27609
 optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
 optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
 optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -70,7 +70,7 @@ optimum/rbln/transformers/models/llama/__init__.py,sha256=jo_j_eIrHYGNEhR5lb6g3r
 optimum/rbln/transformers/models/llama/llama_architecture.py,sha256=S7MCPfyjG5eUqgaS-QNBB0ApUD6wnb5fR0RHq7k7-pA,728
 optimum/rbln/transformers/models/llama/modeling_llama.py,sha256=Z3iony7icoFhRQ11MAuFx9UF03uJCsvJQZ6bxHXlrgk,1530
 optimum/rbln/transformers/models/llava_next/__init__.py,sha256=VLieyWm-UgvuNxw9B38wrL1Jsa09NBDX_ebABmdpTbs,670
-optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=
+optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=w_plsUOzxnhkQBhQeUqW9aJqGCvCvLtsx0XNKYjOprU,26203
 optimum/rbln/transformers/models/midm/__init__.py,sha256=UJSaErsF-z6dZERIS143WTaygffZyzEGqoQ2ZPDiM-c,855
 optimum/rbln/transformers/models/midm/midm_architecture.py,sha256=mueRmMGX6UplZb0C0RFdUOa9lsNH8YJHV6rYrDLOdlQ,5302
 optimum/rbln/transformers/models/midm/modeling_midm.py,sha256=GG25BozEZriAL-OPFGpzOjyDtSFB-NfeiLJTDAqxe20,1734
@@ -84,8 +84,8 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
 optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
 optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
 optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
-optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=
-optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=
+optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=HG_-8ufRWIls67imU1547V0bk9FUWC0haOBL7eyRV6k,16365
+optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=_TL4-vpjM9lfRnQUXRFm3mtVdz_h5B23k01uc_XnW5I,18376
 optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
 optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=MFs-3yYviV1QqSpsTB2GarTEs9wGH5AYofksLQLMBXg,8043
 optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=kkjErS42mW2jv5O_xL7BaKobvvqy7BGmYOowKyHakvI,7189
@@ -108,7 +108,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
 optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
 optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
 optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
-optimum_rbln-0.2.
-optimum_rbln-0.2.
-optimum_rbln-0.2.
-optimum_rbln-0.2.
+optimum_rbln-0.2.1a5.dist-info/METADATA,sha256=WSMoEbo3z3TMFB1lqbdJsu4ZeVI9AtewXktRjMk6WQw,5300
+optimum_rbln-0.2.1a5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+optimum_rbln-0.2.1a5.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+optimum_rbln-0.2.1a5.dist-info/RECORD,,
```
{optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/WHEEL
File without changes
{optimum_rbln-0.2.1a3.dist-info → optimum_rbln-0.2.1a5.dist-info}/licenses/LICENSE
File without changes