optimum-rbln 0.2.1a3__py3-none-any.whl → 0.2.1a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.2.1a3'
+ __version__ = version = '0.2.1a5'
  __version_tuple__ = version_tuple = (0, 2, 1)
@@ -442,8 +442,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  return
 
- real_save_dir = self.model_save_dir / self.subfolder
- save_directory_path = Path(save_directory)
+ # Normalize paths to handle relative paths and symlinks
+ real_save_dir = Path(self.model_save_dir).resolve() / self.subfolder
+ save_directory_path = Path(save_directory).resolve()
 
  if not os.path.exists(real_save_dir) or not os.path.isdir(real_save_dir):
  raise FileNotFoundError(
@@ -452,13 +453,13 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  f"Please ensure the model directory exists and you have the necessary permissions to access it."
  )
 
- if save_directory_path.absolute() == real_save_dir.absolute():
+ if save_directory_path == real_save_dir:
  raise FileExistsError(
  f"Cannot save model to '{save_directory}'. This directory already exists and contains the model files."
  )
 
- # Create a temporary directory next to the target directory
- tmp_dir = save_directory + ".tmp"
+ # Create a temporary directory with normalized path
+ tmp_dir = str(save_directory_path) + ".tmp"
  try:
  # Remove temporary directory if it exists from a previous failed attempt
  if os.path.exists(tmp_dir):
@@ -473,9 +474,9 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  self.generation_config.save_pretrained(tmp_dir)
 
  # If everything succeeded, atomically replace the target directory
- if os.path.exists(save_directory):
- shutil.rmtree(save_directory)
- os.rename(tmp_dir, save_directory)
+ if os.path.exists(save_directory_path):
+ shutil.rmtree(save_directory_path)
+ os.rename(tmp_dir, save_directory_path)
 
  except Exception as e:
  # Clean up the temporary directory if anything fails
@@ -484,7 +485,7 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
  raise e # Re-raise the exception after cleanup
 
  if push_to_hub:
- return super().push_to_hub(save_directory, **kwargs)
+ return super().push_to_hub(str(save_directory_path), **kwargs)
 
  @staticmethod
  def _raise_missing_compiled_file_error(missing_files: List[str]):
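
The save_pretrained changes above boil down to one pattern: resolve both paths up front so relative paths and symlinks compare correctly, stage everything into a sibling ".tmp" directory, and only then swap it into place. A minimal standalone sketch of that pattern (function and argument names are illustrative, not the library's API):

    import os
    import shutil
    from pathlib import Path

    def replace_directory_safely(save_directory: str, write_files) -> None:
        # Resolve first so later comparisons and renames use canonical paths.
        target = Path(save_directory).resolve()
        tmp_dir = str(target) + ".tmp"
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)  # leftover staging dir from a previous failed attempt
        try:
            write_files(tmp_dir)  # caller writes model/config files into the staging dir
            if os.path.exists(target):
                shutil.rmtree(target)
            os.rename(tmp_dir, target)  # move the staged directory into place
        except Exception:
            shutil.rmtree(tmp_dir, ignore_errors=True)
            raise
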
@@ -427,12 +427,14 @@ class DecoderOnlyModel(nn.Module):
  cos, sin = None, None
 
  # (batch, seq_len) -> (batch,)
- seq_positions = cache_position[:, 0]
  if self.attn_impl == "flash_attn":
+ seq_positions = cache_position[:, 0]
  max_seq_len = past_key_values[0][0].shape[-2]
  seq_positions = self.convert_sequence_positions_for_flash_attn(
  seq_positions=seq_positions, max_seq_len=max_seq_len
  )
+ else:
+ seq_positions = cache_position[:, :1]
 
  present_key_values = past_key_values
  for layer in self.layers:
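
The relocated assignment changes what the non-flash path receives: indexing with [:, 0] drops the sequence axis, while slicing with [:, :1] keeps it as a length-1 dimension. A quick shape check in plain PyTorch (independent of the model code):

    import torch

    cache_position = torch.zeros(4, 128, dtype=torch.int32)  # (batch, seq_len)
    print(cache_position[:, 0].shape)   # torch.Size([4])    -- flash_attn branch
    print(cache_position[:, :1].shape)  # torch.Size([4, 1]) -- default branch
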
@@ -38,34 +38,188 @@ from .decoderonly_architecture import (
  logger = get_logger()
 
  if TYPE_CHECKING:
- from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
  class RBLNRuntimeModel(RBLNPytorchRuntime):
  mandatory_members = ["main_input_name", "embed_tokens"]
 
+ def __init__(
+ self,
+ runtime: rebel.Runtime,
+ phase: str,
+ batch_size: int,
+ dec_attn_mask: torch.Tensor,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(runtime, **kwargs)
+ self.phase = phase
+ self.batch_size = batch_size
+
+ # shared tensor between prefill and decode phase
+ self.dec_attn_mask = dec_attn_mask
+
+ if self.phase == "prefill":
+ vocab_size = kwargs.pop("vocab_size")
+ self.max_seq_len = kwargs.pop("max_seq_len")
+ self.prefill_chunk_size = kwargs.pop("prefill_chunk_size")
+ self.output_size = [1, 1, vocab_size]
+ self.causal_mask = 1 - torch.triu(
+ torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+ )
+
  def forward(
  self,
- input_ids: torch.LongTensor,
- inputs_embeds: torch.Tensor,
- attention_mask: torch.Tensor,
- cache_position: torch.Tensor,
- **kwargs,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ batch_idx: Optional[int] = None,
  ):
+ if input_ids is None and inputs_embeds is None:
+ raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
+
  if inputs_embeds is None:
- inp = input_ids
+ inputs = input_ids
  if self.embed_tokens is not None:
- inp = self.embed_tokens(inp)
+ inputs = self.embed_tokens(inputs)
  else:
- inp = inputs_embeds
+ inputs = inputs_embeds
 
- return super().forward(
- inp,
- attention_mask,
+ if self.phase == "decode":
+ return self.decode_forward(
+ inputs,
+ cache_position,
+ attention_mask=attention_mask,
+ )
+ else:
+ return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx)
+
+ def decode_forward(
+ self,
+ inputs: torch.Tensor,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size = inputs.shape[0]
+ if batch_size != self.batch_size:
+ raise RuntimeError(
+ f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+ )
+
+ if batch_size != cache_position.shape[0]:
+ raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+ if attention_mask is None:
+ for b_idx in range(batch_size):
+ decoding_step = cache_position[b_idx].item()
+ if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+ raise ValueError(
+ f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+ )
+ self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
+
+ logits = super().forward(
+ inputs,
+ self.dec_attn_mask if attention_mask is None else attention_mask,
  cache_position,
- **kwargs,
  )
 
+ return logits
+
+ def prefill_forward(
+ self,
+ inputs: torch.Tensor,
+ cache_position: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ batch_idx: int = None,
+ ) -> torch.FloatTensor:
+ """
+ Performs chunked prefill for efficient KV-cache updates and memory optimization.
+ Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
+ and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
+ """
+
+ if batch_idx is None or batch_idx >= self.batch_size:
+ raise RuntimeError(
+ f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
+ )
+
+ # Handle continuous batching in a compiled graph by extracting valid inputs
+ # If an attention mask is provided, select only the valid (non-masked) inputs
+ inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+ query_length = inputs.shape[1]
+ if query_length > self.max_seq_len:
+ raise ValueError(
+ f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+ )
+
+ # Initialize attention mask for chunked processing
+ chunked_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+
+ # Buffer for storing output logits
+ out_buffers = [
+ torch.empty(
+ size=self.output_size,
+ dtype=torch.float32,
+ device="cpu",
+ )
+ ]
+
+ # Process input in chunks of size `prefill_chunk_size`
+ for step in range(0, query_length, self.prefill_chunk_size):
+ # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
+ if (step + self.prefill_chunk_size) > query_length:
+ padding_size = step + self.prefill_chunk_size - query_length
+ # inputs_embeds
+ if inputs.dim() == 3:
+ inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+ # input_ids
+ else:
+ inputs = torch.nn.functional.pad(inputs, (0, padding_size))
+
+ cache_position = torch.cat(
+ [
+ cache_position,
+ torch.arange(
+ query_length,
+ step + self.prefill_chunk_size,
+ dtype=torch.int32,
+ ).unsqueeze(0),
+ ],
+ dim=-1,
+ )
+
+ # Extract the current chunk of inputs and cache positions
+ input_chunk = inputs[:, step : step + self.prefill_chunk_size]
+ cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+
+ # Update attention mask to ensure proper causal behavior
+ if step >= self.prefill_chunk_size:
+ chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+ chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+ # Define batch position and query position
+ batch_position = torch.tensor(batch_idx, dtype=torch.int16)
+ query_position = torch.tensor((query_length - 1) % self.prefill_chunk_size, dtype=torch.int16)
+
+ # Forward pass for the current chunk
+ logits = super().forward(
+ input_chunk,
+ chunked_attention_mask,
+ cache_pos_chunk,
+ batch_position,
+ query_position,
+ out=out_buffers,
+ )
+
+ # Update decoder attention mask with processed KV-cache length from prefill phase
+ self.dec_attn_mask[batch_idx].fill_(0)
+ self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
+
+ return logits
+
 
  @dataclass
  class RBLNDecoderOnlyOutput(ModelOutput):
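
The chunked prefill added above walks the prompt in prefill_chunk_size steps, exposing previously processed chunks fully and applying a causal block to the current chunk. A minimal sketch of just that masking arithmetic (plain PyTorch, stub values; not the RBLN runtime call):

    import torch

    def chunked_prefill_masks(query_length: int, chunk_size: int, max_seq_len: int):
        # Lower-triangular block reused for the chunk currently being processed.
        causal_mask = 1 - torch.triu(torch.ones(1, 1, chunk_size, chunk_size), diagonal=1)
        mask = torch.zeros(1, 1, chunk_size, max_seq_len)
        for step in range(0, query_length, chunk_size):
            if step >= chunk_size:
                # Chunks processed in earlier iterations become fully visible.
                mask[:, :, :, step - chunk_size : step] = 1
            # The current chunk sees itself causally.
            mask[:, :, :, step : step + chunk_size] = causal_mask
            yield step, mask.clone()

    # Example: a 10-token prompt, chunk size 4, max sequence length 16.
    for step, mask in chunked_prefill_masks(10, 4, 16):
        print(step, mask[0, 0, -1, :12].tolist())
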
@@ -103,13 +257,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
  self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
 
- self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
- self.causal_mask = 1 - torch.triu(
- torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
- )
- self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
- self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
-
  main_input_name = self.main_input_name
  if self.rbln_config.model_cfg["use_inputs_embeds"]:
  main_input_name = "inputs_embeds"
@@ -124,11 +271,25 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  else:
  self.embed_tokens = None
 
+ dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
  self.prefill_decoder = RBLNRuntimeModel(
- runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+ runtime=self.model[0],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="prefill",
+ batch_size=self.batch_size,
+ dec_attn_mask=dec_attn_mask,
+ vocab_size=self.config.vocab_size,
+ max_seq_len=self.max_seq_len,
+ prefill_chunk_size=self.prefill_chunk_size,
  )
  self.decoder = RBLNRuntimeModel(
- runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+ runtime=self.model[1],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="decode",
+ batch_size=self.batch_size,
+ dec_attn_mask=dec_attn_mask,
  )
 
  @classmethod
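
Note that both runtime wrappers are constructed with the same dec_attn_mask tensor, so the in-place updates made at the end of prefill_forward are what decode_forward later reads. A toy illustration of that sharing (no RBLN dependency; Wrapper is a stand-in):

    import torch

    class Wrapper:
        def __init__(self, dec_attn_mask: torch.Tensor) -> None:
            self.dec_attn_mask = dec_attn_mask  # stored by reference, not copied

    shared = torch.zeros(2, 1, 1, 8)
    prefill, decode = Wrapper(shared), Wrapper(shared)

    # Prefill marks the first three cache positions of batch slot 0 as filled...
    prefill.dec_attn_mask[0, :, :, :3] = 1
    # ...and the decode wrapper sees the same values.
    print(decode.dec_attn_mask[0, 0, 0].tolist())  # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
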
@@ -155,7 +316,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  def get_quantized_model(
  cls,
  model_id: str,
- config: Optional[PretrainedConfig] = None,
+ config: Optional["PretrainedConfig"] = None,
  use_auth_token: Optional[Union[bool, str]] = None,
  revision: Optional[str] = None,
  force_download: bool = False,
@@ -496,32 +657,33 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  generate_idx: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
- # prefll
+ """
+ Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
+ For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
+ A for-loop ensures synchronization with the HuggingFace generate API.
+ The decoder stage operates as usual, processing inputs in batch mode.
+ """
+ # Prefill
  if cache_position is None:
  logits = []
- input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
- batch_size = input_tensors.shape[0]
+ inputs = inputs_embeds if inputs_embeds is not None else input_ids
+ batch_size = inputs.shape[0]
 
  for b_idx in range(batch_size):
- # Transform inputs as vllm format
- if attention_mask is not None:
- input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
- else:
- input_tensor = input_tensors[b_idx : b_idx + 1]
-
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
- logit = self._forward_prefill(
- input_ids=input_tensor if inputs_embeds is None else None,
- inputs_embeds=input_tensor if inputs_embeds is not None else None,
+ logit = self.prefill_decoder(
+ input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+ inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+ attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
  cache_position=cache_position,
  batch_idx=b_idx,
  )
  logits.append(logit)
+
  logits = torch.cat(logits, dim=0)
- # decoder
+ # Decoder
  else:
- logits = self._forward_decoder(
+ logits = self.decoder(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
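
Condensed, the dispatch described in the docstring is: no cache_position means a prefill pass executed once per batch slot, otherwise a single batched decode step. A stripped-down sketch of that control flow (illustrative names, not the full signature):

    import torch

    def causal_lm_forward(model, inputs, generate_idx, cache_position=None):
        if cache_position is None:
            # Prefill: one sequence at a time, so each slot fills its own KV cache.
            logits = []
            for b_idx in range(inputs.shape[0]):
                pos = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                logits.append(model.prefill_decoder(input_ids=inputs[b_idx : b_idx + 1],
                                                    cache_position=pos, batch_idx=b_idx))
            return torch.cat(logits, dim=0)
        # Decode: all sequences advance together at the compiled batch size.
        return model.decoder(input_ids=inputs, cache_position=cache_position)
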
@@ -531,119 +693,3 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  logits=logits,
  generate_idx=generate_idx,
  )
-
- def _forward_prefill(
- self,
- input_ids: torch.LongTensor = None,
- inputs_embeds: torch.Tensor = None,
- cache_position: torch.Tensor = None,
- batch_idx: int = None,
- ) -> torch.FloatTensor:
- if batch_idx is None or batch_idx >= self.batch_size:
- raise RuntimeError(
- f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
- )
-
- out_buffers = [
- torch.empty(
- size=[
- 1,
- 1,
- self.config.vocab_size,
- ],
- dtype=torch.float32,
- device="cpu",
- )
- ]
-
- input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
- query_length = input_tensors.shape[1]
- if query_length > self.max_seq_len:
- raise ValueError(
- f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
- )
-
- _attention_mask = self.prefill_attention_mask.clone()
-
- for step in range(0, query_length, self.prefill_chunk_size):
- # pad input_tensors & cache_position for prefill_chunk
- if (step + self.prefill_chunk_size) > query_length:
- pad_to_chunk = step + self.prefill_chunk_size - query_length
- if inputs_embeds is not None:
- input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
- else:
- input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
-
- cache_position = torch.cat(
- [
- cache_position,
- torch.arange(
- query_length,
- step + self.prefill_chunk_size,
- dtype=torch.int32,
- ).unsqueeze(0),
- ],
- dim=-1,
- )
-
- # slice input_tensor & cache_position with prefill_chunk_size
- _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
- _cache_position = cache_position[:, step : step + self.prefill_chunk_size]
-
- # update attention_mask
- if step >= self.prefill_chunk_size:
- _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
- _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-
- query_position = (query_length - 1) % self.prefill_chunk_size
-
- logits = self.prefill_decoder(
- input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
- inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
- attention_mask=_attention_mask.contiguous(),
- cache_position=_cache_position.contiguous(),
- batch_position=torch.tensor(batch_idx, dtype=torch.int16),
- query_position=torch.tensor(query_position, dtype=torch.int16),
- out=out_buffers,
- )
-
- # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
- self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
- self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
- return logits
-
- def _forward_decoder(
- self,
- input_ids: torch.LongTensor = None,
- inputs_embeds: torch.Tensor = None,
- cache_position: torch.Tensor = None,
- ) -> torch.FloatTensor:
- input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
- if input_tensors is None:
- raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
- batch_size = input_tensors.shape[0]
- if batch_size != self.batch_size:
- raise RuntimeError(
- f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
- )
-
- if batch_size != cache_position.shape[0]:
- raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
- for b_idx in range(batch_size):
- decoding_step = cache_position[b_idx].item()
- if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
- raise ValueError(
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
- )
- self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
- logits = self.decoder(
- input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
- inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
- attention_mask=self.dec_attn_mask.contiguous(),
- cache_position=cache_position.contiguous(),
- )
-
- return logits
@@ -25,7 +25,6 @@ from transformers import (
  PreTrainedModel,
  )
  from transformers.modeling_outputs import BaseModelOutputWithPooling
- from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
  from ....modeling import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
@@ -337,7 +336,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  generate_idx: Optional[torch.Tensor] = None,
  batch_idx: Optional[int] = None,
  **kwargs,
- ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+ ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
  vision_feature_layer = (
  vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  )
@@ -378,7 +377,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
  for batch_idx in range(batch_size):
  generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
- logit = self.language_model._forward_prefill(
+ logit = self.language_model.prefill_decoder(
  inputs_embeds=inputs_embeds[batch_idx],
  batch_idx=batch_idx,
  cache_position=torch.arange(
@@ -390,15 +389,13 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
  logits.append(logit)
  logits = torch.cat(logits, dim=0)
- outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
  else:
- outputs: RBLNDecoderOnlyOutput = self.language_model(
+ logits = self.language_model.decoder(
  inputs_embeds=inputs_embeds,
  cache_position=cache_position,
- generate_idx=generate_idx,
  )
 
- return outputs
+ return RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
 
  # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
  def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
+ from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
  from ....modeling import RBLNModel
@@ -31,12 +31,7 @@ from ....utils.runtime_utils import RBLNPytorchRuntime
  logger = get_logger(__name__)
 
  if TYPE_CHECKING:
- from transformers import (
- AutoFeatureExtractor,
- AutoProcessor,
- AutoTokenizer,
- PretrainedConfig,
- )
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig
 
 
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
@@ -50,9 +45,50 @@ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
  class RBLNRuntimeDecoder(RBLNPytorchRuntime):
  mandatory_members = ["main_input_name"]
 
- def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
- outputs = super().forward(*args, **kwargs)
- return Seq2SeqLMOutput(logits=outputs)
+ def __init__(
+ self,
+ runtime: rebel.Runtime,
+ batch_size: int,
+ dec_max_seq_len: int,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(runtime, **kwargs)
+ self.batch_size = batch_size
+ self.dec_max_seq_len = dec_max_seq_len
+
+ def forward(
+ self,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor]:
+ batch_size = decoder_input_ids.shape[0]
+ if batch_size != self.batch_size:
+ raise RuntimeError(
+ f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+ )
+
+ if batch_size != cache_position.shape[0]:
+ raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+ for b_idx in range(self.batch_size):
+ decoding_step = cache_position[b_idx].item()
+ if not (0 <= decoding_step < self.dec_max_seq_len):
+ raise ValueError(
+ f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
81
+ )
+ decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+ lm_logits = super().forward(
+ decoder_input_ids,
+ decoder_attention_mask,
+ attention_mask,
+ cache_position,
+ )
+
+ return Seq2SeqLMOutput(logits=lm_logits)
 
 
  class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
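
The loop in the new RBLNRuntimeDecoder.forward simply opens positions 0..decoding_step of the decoder mask for each batch entry before invoking the compiled decoder. The same update in isolation (plain PyTorch, illustrative sizes):

    import torch

    batch_size, dec_max_seq_len = 2, 8
    decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len)
    cache_position = torch.tensor([[3], [5]])  # current decoding step per sequence

    for b_idx in range(batch_size):
        decoding_step = cache_position[b_idx].item()
        decoder_attention_mask[b_idx, : decoding_step + 1] = 1  # attend to all tokens generated so far

    print(decoder_attention_mask.tolist())
    # [[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
    #  [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]]
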
@@ -72,8 +108,15 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  auto_model_class = AutoModelForSeq2SeqLM
 
  def __post_init__(self, **kwargs):
- self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
- self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
+ batch_size = self.rbln_config.model_cfg["batch_size"]
+ dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+ self.encoder = RBLNRuntimeEncoder(
+ runtime=self.model[0],
+ main_input_name="input_ids",
+ )
+ self.decoder = RBLNRuntimeDecoder(
+ runtime=self.model[1], main_input_name="input_ids", batch_size=batch_size, dec_max_seq_len=dec_max_seq_len
+ )
 
  @classmethod
  @torch.inference_mode()
@@ -304,46 +347,24 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
  def forward(
  self,
- input_ids: torch.LongTensor = None,
+ decoder_input_ids: torch.LongTensor = None,
  cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
  **kwargs,
  ) -> Tuple[torch.FloatTensor]:
  # common decoder
  cache_position = torch.full((self.rbln_config.model_cfg["batch_size"], 1), cache_position, dtype=torch.int32)
- logits = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs).logits
+ logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
 
  return Seq2SeqLMOutput(
  logits=logits,
  )
 
- def _forward_decoder(
- self,
- attention_mask: Optional[torch.FloatTensor] = None,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
- cache_position: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> Tuple[torch.FloatTensor]:
- dec_attention_mask = decoder_attention_mask.clone()
- for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
- dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
-
- decoder_output = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=dec_attention_mask,
- encoder_attention_mask=attention_mask,
- cache_position=cache_position,
- )
- lm_logits = decoder_output.logits
-
- return Seq2SeqLMOutput(logits=lm_logits)
-
  def _prepare_encoder_decoder_kwargs_for_generation(
  self,
  inputs_tensor: torch.Tensor,
  model_kwargs,
  model_input_name: Optional[str] = None,
- generation_config: Optional[GenerationConfig] = None,
+ generation_config: Optional["GenerationConfig"] = None,
  ) -> Dict[str, Any]:
  # 1. get encoder
  encoder = self.get_encoder()
@@ -373,6 +394,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
  )
 
  # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
  encoder_kwargs["return_dict"] = True
  encoder_kwargs["output_hidden_states"] = False
  encoder_kwargs["output_attentions"] = False
@@ -459,7 +459,7 @@ class Seq2SeqSelfAttention(nn.Module):
  ), # Unsqueeze group axis since CustomKernel expects it for group query attention
  past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
  past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
- cache_position.squeeze(1),
+ cache_position,
  torch.tensor(1.0, dtype=torch.float32), # scale
  )
 
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: optimum-rbln
- Version: 0.2.1a3
+ Version: 0.2.1a5
  Summary: Optimum RBLN is the interface between the Hugging Face Transformers and Diffusers libraries and RBLN accelerators. It provides a set of tools enabling easy model loading and inference on single and multiple rbln device settings for different downstream tasks.
  Project-URL: Homepage, https://rebellions.ai
  Project-URL: Documentation, https://docs.rbln.ai
@@ -1,7 +1,7 @@
  optimum/rbln/__init__.py,sha256=sLCjJu_MLZEKDOwHIlJP4u4GzGZx-1kqHTYGw5B4xDg,6096
- optimum/rbln/__version__.py,sha256=Qa8tLTuiehljsgp_ibSY6aee43cZYh5J_fQ5zMTZ6SA,413
+ optimum/rbln/__version__.py,sha256=J4Eyn4HLzB0UpyosVo-P3LCDkB5knEOS6Nu24mnl5NA,413
  optimum/rbln/modeling.py,sha256=REImAAKO82CqSNABR-9E1jJEsWch9amSOwOOQhFEYLY,8283
- optimum/rbln/modeling_base.py,sha256=_5M8hVySDwCJ6qfeku2_nJAPu_5JLfAUu3HO1bc3ALM,21098
+ optimum/rbln/modeling_base.py,sha256=fQ0bI1Bb6GJquRXftmSSN9K-TXLhFltZJ6C-2w43xMg,21193
  optimum/rbln/modeling_config.py,sha256=7104bxmrvKW4Q6XTruQayiIGl8GHDFmPkJ3cknMIInE,11335
  optimum/rbln/diffusers/__init__.py,sha256=68FTAMpbbMflm8qiSqfM5J2_gFb3iU3fng6AL0TG47A,2913
  optimum/rbln/diffusers/modeling_diffusers.py,sha256=E1x-iOKEJCUB6ml0RgtFEVPPk6J6pqEF-JTEyOZzOyc,14928
@@ -53,8 +53,8 @@ optimum/rbln/transformers/models/bert/modeling_bert.py,sha256=-nv-sgmHkyHQIoQvF8
  optimum/rbln/transformers/models/clip/__init__.py,sha256=ssJqlEt318ti2QaEakGh_tO3Ap1VSPCVF-ymUuvjAJs,698
  optimum/rbln/transformers/models/clip/modeling_clip.py,sha256=E1QfVNq1sTCp7uvuha1ZPfXMwvMTkGV9L4oFdmy1w4g,5724
  optimum/rbln/transformers/models/decoderonly/__init__.py,sha256=pDogsdpJKKB5rqnVFrRjwfhUvOSV-jZ3oARMsqSvOOQ,665
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=BjQHwoPZfM-KUQzxm4AU-PdmoMgLxnCG6kfSpGjUvrk,36578
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=mAgRRMGVHvTUjJBDlmUOjNhSNjprKSD7tLeFknrx0Rw,25810
+ optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py,sha256=eT1fbKDL92yGBXtUKA_JibD4kiRPdf3tAFJHP5nlfH4,36646
+ optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py,sha256=2OO8MEgFgcl1VPrQXxqkvmRJJEuFdexwu8XqbHDbR6Y,27609
  optimum/rbln/transformers/models/dpt/__init__.py,sha256=gP1tkR3XMNlHq1GT87ugIVvb2o_1eAUg1JaniXjy1Lw,651
  optimum/rbln/transformers/models/dpt/modeling_dpt.py,sha256=ZsS2SOiqcA4azULB-WFEMQZbgIoOyVUKqVKqrw_tWzA,3430
  optimum/rbln/transformers/models/exaone/__init__.py,sha256=zYH_5tVa8-juEdsOIky7I33WSC3Zuhoq1upI0OHYeVw,859
@@ -70,7 +70,7 @@ optimum/rbln/transformers/models/llama/__init__.py,sha256=jo_j_eIrHYGNEhR5lb6g3r
  optimum/rbln/transformers/models/llama/llama_architecture.py,sha256=S7MCPfyjG5eUqgaS-QNBB0ApUD6wnb5fR0RHq7k7-pA,728
  optimum/rbln/transformers/models/llama/modeling_llama.py,sha256=Z3iony7icoFhRQ11MAuFx9UF03uJCsvJQZ6bxHXlrgk,1530
  optimum/rbln/transformers/models/llava_next/__init__.py,sha256=VLieyWm-UgvuNxw9B38wrL1Jsa09NBDX_ebABmdpTbs,670
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=_8zKsI-Kj4bbsPLnERJqg-0oC6EyAWrmnxvszsAtRaA,26398
+ optimum/rbln/transformers/models/llava_next/modeling_llava_next.py,sha256=w_plsUOzxnhkQBhQeUqW9aJqGCvCvLtsx0XNKYjOprU,26203
  optimum/rbln/transformers/models/midm/__init__.py,sha256=UJSaErsF-z6dZERIS143WTaygffZyzEGqoQ2ZPDiM-c,855
  optimum/rbln/transformers/models/midm/midm_architecture.py,sha256=mueRmMGX6UplZb0C0RFdUOa9lsNH8YJHV6rYrDLOdlQ,5302
  optimum/rbln/transformers/models/midm/modeling_midm.py,sha256=GG25BozEZriAL-OPFGpzOjyDtSFB-NfeiLJTDAqxe20,1734
@@ -84,8 +84,8 @@ optimum/rbln/transformers/models/qwen2/__init__.py,sha256=RAMWc21W_2I6DH9xBjeNxP
  optimum/rbln/transformers/models/qwen2/modeling_qwen2.py,sha256=9-aFDvjMzPNUyGOz0qo33RE18bUFGYZ3Wt_68zb5uJY,1530
  optimum/rbln/transformers/models/qwen2/qwen2_architecture.py,sha256=XlNAMYAcDLohnSAhIFGKOPuCB5XLgzYs5ABWdeQSaZs,720
  optimum/rbln/transformers/models/seq2seq/__init__.py,sha256=EmEMV4rOYqKyruX85d0fR73-b8N6BSD6CPcbpYdBuVk,651
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=2hkCPvaiyS16zdtUiJKhvpk1qJfsXVLrAQPgAtixCg0,15426
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=15yoF-wyhcLcK-Z2MOUmyPlkOMNTVOJ013uBepqtpxA,18387
+ optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py,sha256=HG_-8ufRWIls67imU1547V0bk9FUWC0haOBL7eyRV6k,16365
+ optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py,sha256=_TL4-vpjM9lfRnQUXRFm3mtVdz_h5B23k01uc_XnW5I,18376
  optimum/rbln/transformers/models/t5/__init__.py,sha256=1skR1RmnG62WTAP3-F5P1x-V_ReFhMyirH3u56vWwvc,675
  optimum/rbln/transformers/models/t5/modeling_t5.py,sha256=MFs-3yYviV1QqSpsTB2GarTEs9wGH5AYofksLQLMBXg,8043
  optimum/rbln/transformers/models/t5/t5_architecture.py,sha256=kkjErS42mW2jv5O_xL7BaKobvvqy7BGmYOowKyHakvI,7189
@@ -108,7 +108,7 @@ optimum/rbln/utils/model_utils.py,sha256=DfD_Z2qvZHqcddXqnzTM1AN8khanj3-DXK2lJvV
  optimum/rbln/utils/runtime_utils.py,sha256=5-DYniyP59nx-mrrbi7AqA77L85b4Cm5oLpaxidSyss,3699
  optimum/rbln/utils/save_utils.py,sha256=hG5uOtYmecSXZuGTvCXsTM-SiyZpr5q3InUGCCq_jzQ,3619
  optimum/rbln/utils/submodule.py,sha256=oZoGrItB8WqY4i-K9WJPlLlcLohc1YGB9OHB8_XZw3A,4071
- optimum_rbln-0.2.1a3.dist-info/METADATA,sha256=umGg7JkKhTcNc5AOyzubqzpoPXnGY1WosDi48dfAROw,5300
- optimum_rbln-0.2.1a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- optimum_rbln-0.2.1a3.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- optimum_rbln-0.2.1a3.dist-info/RECORD,,
+ optimum_rbln-0.2.1a5.dist-info/METADATA,sha256=WSMoEbo3z3TMFB1lqbdJsu4ZeVI9AtewXktRjMk6WQw,5300
+ optimum_rbln-0.2.1a5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ optimum_rbln-0.2.1a5.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ optimum_rbln-0.2.1a5.dist-info/RECORD,,