optimum-rbln 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. optimum/rbln/__init__.py +10 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +0 -2
  4. optimum/rbln/diffusers/models/controlnet.py +0 -6
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +0 -3
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +18 -20
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -20
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +19 -34
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +20 -35
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +12 -13
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +13 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +13 -14
  15. optimum/rbln/modeling_alias.py +4 -9
  16. optimum/rbln/modeling_base.py +105 -139
  17. optimum/rbln/modeling_config.py +51 -0
  18. optimum/rbln/transformers/__init__.py +8 -0
  19. optimum/rbln/transformers/models/__init__.py +4 -1
  20. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  21. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  22. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  23. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  24. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  25. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +172 -100
  27. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  28. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  29. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  30. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  31. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  32. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  33. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +148 -152
  34. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -0
  35. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  36. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  37. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  38. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  39. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  40. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  41. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  42. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  43. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  44. optimum/rbln/transformers/models/whisper/modeling_whisper.py +37 -12
  45. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  46. optimum/rbln/utils/import_utils.py +14 -0
  47. optimum/rbln/utils/logging.py +1 -1
  48. optimum/rbln/utils/runtime_utils.py +1 -1
  49. optimum/rbln/utils/timer_utils.py +26 -2
  50. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +4 -3
  51. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/RECORD +54 -44
  52. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  53. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/entry_points.txt +0 -0
  54. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,78 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import inspect
+ import logging
+ from typing import TYPE_CHECKING, Any, Callable
+
+ from ....modeling_config import RBLNConfig
+ from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .exaone_architecture import ExaoneForCausalLMWrapper
+ from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM
+
+
+ logger = logging.getLogger(__name__)
+ if TYPE_CHECKING:
+ from transformers import (
+ PreTrainedModel,
+ )
+
+
+ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+ """
+ The Exaone Model transformer with a language modeling head on top (linear layer with weights tied to the input
+ embeddings).
+
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
+ library implements for all its model.
+
+ It implements the methods to convert a pre-trained transformers Exaone model into a RBLN transformer model by:
+ - transferring the checkpoint weights of the original into an optimized RBLN graph,
+ - compiling the resulting graph using the RBLN compiler.
+
+ """
+
+ @classmethod
+ def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+ rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
+ return ExaoneForCausalLMWrapper(model, rbln_max_seq_len).eval()
+
+ def __getattr__(self, __name: str) -> Any:
+ """This is the key method to implement RBLN-Exaone.
+
+ Returns:
+ Any: Exaone's corresponding method
+ """
+
+ def redirect(func):
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+ val = getattr(ExaoneForCausalLM, __name)
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+ return redirect(val)
+ return val
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ kwargs.setdefault("trust_remote_code", True)
+ return super().from_pretrained(*args, **kwargs)
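A minimal usage sketch for the new Exaone entry point follows. Only the class name, the max_seq_len plumbing in wrap_model_if_needed, and the trust_remote_code=True default come from the code above; the checkpoint id and the export/rbln_* keyword arguments are illustrative assumptions, and the top-level import assumes the re-export suggested by the optimum/rbln/__init__.py changes in this release.

    # Hypothetical usage sketch; model id and export/rbln_* kwargs are assumptions, not from this diff.
    from optimum.rbln import RBLNExaoneForCausalLM

    model = RBLNExaoneForCausalLM.from_pretrained(
        "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",  # assumed checkpoint name
        export=True,            # assumed flag: compile the HF checkpoint into an RBLN graph
        rbln_max_seq_len=4096,  # feeds rbln_config.model_cfg["max_seq_len"] used by the wrapper above
    )
    # trust_remote_code=True is set by default in from_pretrained() above,
    # so the hf_hub_cached Exaone modeling code loads without an extra flag.
    model.save_pretrained("exaone-rbln")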
@@ -114,7 +114,7 @@ class LoopProjector:
  return self.forward(*args, **kwds)

  def __repr__(self) -> str:
- return repr(self.vision_tower)
+ return repr(self.multi_modal_projector)


  class RBLNLlavaNextForConditionalGeneration(RBLNModel):
@@ -228,29 +228,26 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  pixel_values=None,
  image_sizes=None,
  attention_mask=None,
- past_cached_length=None,
+ generate_idx=None,
  **kwargs,
  ):
  # Prepare HF generation
- is_prefill_phase = past_cached_length is None
+ is_prefill_phase = generate_idx is None
  batch_size = input_ids.shape[0]

  model_inputs = self.language_model.prepare_inputs_for_generation(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
- past_cached_length=past_cached_length, # Not affect
+ generate_idx=generate_idx, # Not affect
  attention_mask=attention_mask,
  **kwargs,
  )

  if is_prefill_phase:
- model_inputs["past_cached_length"] = torch.zeros((batch_size, 1), dtype=torch.int32)
- else:
- model_inputs["past_cached_length"] = past_cached_length + 1
+ model_inputs["generate_idx"] = torch.zeros((batch_size, 1), dtype=torch.int32)

  model_inputs.update(
  {
- # "position_ids": position_ids or cache_positions,
  "pixel_values": pixel_values,
  "image_sizes": image_sizes,
  "attention_mask": attention_mask,
@@ -264,43 +261,28 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  model_kwargs: Dict[str, Any],
  **kwargs,
  ) -> Dict[str, Any]:
- # update past_cached_length
- model_kwargs["past_cached_length"] = outputs.past_cached_length
+ # update generate_idx
+ model_kwargs["generate_idx"] = outputs.generate_idx

  return model_kwargs

- def _merge_vllm_multimodal_embeddings(
+ def text_embedding(
  self,
- input_ids: torch.Tensor,
- inputs_embeds: torch.Tensor,
- multimodal_embeddings: torch.Tensor,
- placeholder_token_id: int,
+ input_ids: torch.LongTensor,
  ) -> torch.Tensor:
- mask = input_ids == placeholder_token_id
- num_expected_tokens = mask.sum().item()
- assert isinstance(num_expected_tokens, int)
-
- if multimodal_embeddings.shape[0] != num_expected_tokens:
- raise ValueError(
- f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
- f"multimodal tokens to {num_expected_tokens} placeholders"
- )
+ for_inputs_embeds_ids = input_ids.clone()
+ for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
+ inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)

- inputs_embeds[mask] = multimodal_embeddings
  return inputs_embeds

- def _embed(
+ def image_embedding(
  self,
- input_ids: torch.LongTensor,
  image_sizes: torch.LongTensor,
- attention_mask: torch.Tensor,
  pixel_values: torch.FloatTensor,
  vision_feature_layer: int,
  vision_feature_select_strategy: str,
- cache_position: torch.Tensor,
- past_cached_length: torch.Tensor,
- from_vllm_prefill: bool = False,
- ) -> List[torch.Tensor]:
+ ) -> torch.Tensor:
  vision_feature_layer = (
  vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  )
@@ -310,159 +292,173 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
  else self.config.vision_feature_select_strategy
  )

- # 1. Extract the input embeddings
- # In case image_token_index is not in the embeddings (extra token but embedding don't have it)
- for_inputs_embeds_ids = input_ids.clone()
- for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
-
- inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
-
- # 2. Merge text and images
- if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
- # ! infer image_num_patches from image_sizes
- image_num_patches = [
- image_size_to_num_patches(
- image_size=imsize,
- grid_pinpoints=self.config.image_grid_pinpoints,
- patch_size=self.config.vision_config.image_size,
- )
- for imsize in image_sizes
- ]
- # figure out if pixel_values is concatenated or stacked
- if pixel_values.dim() == 5:
- # stacking when input is (batch_size, num_patches, num_channels, height, width)
- _pixel_values_list = [
- pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
- ]
- pixel_values = torch.cat(_pixel_values_list, dim=0)
- elif pixel_values.dim() != 4:
- # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
- raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
-
- image_features = self.vision_tower(pixel_values, output_hidden_states=True)
- selected_image_feature = image_features.hidden_states[vision_feature_layer]
-
- if vision_feature_select_strategy == "default":
- selected_image_feature = selected_image_feature[:, 1:]
- elif vision_feature_select_strategy == "full":
- selected_image_feature = selected_image_feature
-
- image_features = self.multi_modal_projector(selected_image_feature)
- image_features = torch.split(image_features, image_num_patches, dim=0)
-
- # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
- image_features, feature_lens = self.pack_image_features(
- image_features,
- image_sizes,
- image_newline=self.image_newline,
+ # ! infer image_num_patches from image_sizes
+ image_num_patches = [
+ image_size_to_num_patches(
+ image_size=imsize,
+ grid_pinpoints=self.config.image_grid_pinpoints,
+ patch_size=self.config.vision_config.image_size,
  )
+ for imsize in image_sizes
+ ]

- inputs_embeds = inputs_embeds.to(image_features.dtype)
-
- if from_vllm_prefill:
- self._merge_vllm_multimodal_embeddings(
- input_ids, inputs_embeds, image_features, self.config.image_token_index
- )
- else:
- inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
- image_features,
- feature_lens,
- inputs_embeds,
- input_ids,
- attention_mask,
- )
+ # figure out if pixel_values is concatenated or stacked
+ if pixel_values.dim() == 5:
+ # stacking when input is (batch_size, num_patches, num_channels, height, width)
+ _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
+ elif pixel_values.dim() != 4:
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

- cache_position = torch.arange(0, inputs_embeds.shape[1], dtype=torch.int32).unsqueeze_(0)
+ image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+ selected_image_feature = image_features.hidden_states[vision_feature_layer]

- # pixel_values is not None but is empty ---> text only cases
- elif (
- pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0 or pixel_values is None
- ):
- pass
-
- # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
- # generation with cache
- elif pixel_values is not None and input_ids.shape[1] == 1 and past_cached_length is not None:
- cache_position = past_cached_length
+ if vision_feature_select_strategy == "default":
+ selected_image_feature = selected_image_feature[:, 1:]
+ elif vision_feature_select_strategy == "full":
+ selected_image_feature = selected_image_feature
+
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = torch.split(image_features, image_num_patches, dim=0)
+
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+ image_features, feature_lens = self.pack_image_features(
+ image_features,
+ image_sizes,
+ image_newline=self.image_newline,
+ )

- return inputs_embeds, cache_position
+ return image_features, feature_lens

  def forward(
  self,
  input_ids: torch.LongTensor = None,
+ attention_mask: torch.LongTensor = None,
  pixel_values: torch.FloatTensor = None,
  image_sizes: Optional[torch.LongTensor] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  vision_feature_layer: Optional[int] = None,
  vision_feature_select_strategy: Optional[str] = None,
- cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
- batch_idx: Optional[int] = None,
- past_cached_length: Optional[torch.Tensor] = None,
+ cache_position: torch.Tensor = None,
+ generate_idx: Optional[torch.Tensor] = None,
  **kwargs,
  ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
- from_vllm_prefill = isinstance(cache_position, torch.Tensor) and cache_position.shape[-1] > 1
- from_hf_generate_prefill = isinstance(input_ids, list)
-
  if inputs_embeds is not None:
  raise NotImplementedError("Specifying inputs_embeds is not supported.")

- if from_hf_generate_prefill:
- inputs_embeds = []
- batch_size = len(input_ids)
+ is_prefill_phase = not generate_idx.bool().all()

+ if is_prefill_phase:
  # Get the number of images in the prompt
  special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
  num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]

  # Split images for each prompt
- pixel_values = pixel_values.split(num_special_image_tokens, dim=0)
- image_sizes = image_sizes.split(num_special_image_tokens, dim=0)
-
- for b_idx in range(batch_size):
- embed, cache_pos = self._embed(
- input_ids=input_ids[b_idx],
- image_sizes=image_sizes[b_idx] if image_sizes is not None else None,
- attention_mask=torch.ones_like(input_ids[b_idx]),
- pixel_values=pixel_values[b_idx] if pixel_values is not None else None,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- cache_position=cache_position[b_idx],
- past_cached_length=past_cached_length[b_idx : b_idx + 1],
+ if pixel_values is not None and pixel_values.size(0) > 0:
+ pixel_values = pixel_values.split(num_special_image_tokens, dim=0)
+ image_sizes = image_sizes.split(num_special_image_tokens, dim=0)
+
+ logits = []
+ for b_idx in range(input_ids.shape[0]):
+ # Get text_embeds from input_id
+ input_id = input_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+ inputs_embed = self.text_embedding(input_id)
+
+ # If any images in the prompt, get image_embeds and merge with text
+ if num_special_image_tokens[b_idx] > 0:
+ image_features, feature_lens = self.image_embedding(
+ image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
+ )
+ inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+ image_features,
+ feature_lens,
+ inputs_embed.to(image_features.dtype),
+ input_id,
+ torch.ones_like(input_id, dtype=torch.long),
+ )
+
+ # Update generate_idx according to inputs_embed
+ generate_idx[b_idx] = inputs_embed.shape[1]
+
+ logit = self.language_model._forward_prefill(
+ inputs_embeds=inputs_embed,
+ batch_idx=b_idx,
+ cache_position=torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0),
  )
- inputs_embeds.append(embed)
- cache_position[b_idx] = cache_pos
- past_cached_length[b_idx] += embed.shape[1]
-
- elif from_vllm_prefill:
- inputs_embeds, cache_position = self._embed(
- input_ids=input_ids,
- image_sizes=image_sizes,
- attention_mask=torch.ones_like(input_ids),
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
- cache_position=cache_position,
- past_cached_length=past_cached_length,
- from_vllm_prefill=from_vllm_prefill,
- )
+
+ logits.append(logit)
+
+ logits = torch.cat(logits, dim=0)
+ outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
+
  else:
- # Decoding step
- inputs_embeds, cache_position = self._embed(
- input_ids=input_ids,
- image_sizes=image_sizes,
- attention_mask=torch.ones_like(input_ids),
- pixel_values=pixel_values,
- vision_feature_layer=vision_feature_layer,
- vision_feature_select_strategy=vision_feature_select_strategy,
+ inputs_embeds = self.text_embedding(input_ids)
+
+ outputs: RBLNDecoderOnlyOutput = self.language_model(
+ inputs_embeds=inputs_embeds,
  cache_position=cache_position,
- past_cached_length=past_cached_length,
+ generate_idx=generate_idx,
  )

- outputs: RBLNDecoderOnlyOutput = self.language_model(
+ return outputs
+
+ def vllm_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ image_sizes: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[int] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
+ batch_idx: Optional[int] = None,
+ **kwargs,
+ ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+ is_prefill = cache_position.shape[-1] > 1
+
+ if inputs_embeds is not None:
+ raise NotImplementedError("Specifying inputs_embeds is not supported.")
+
+ if is_prefill:
+ # Get text_embeds
+ inputs_embeds = self.text_embedding(input_ids)
+
+ # If any images in the prompt, get image_embeds and merge with text
+ if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
+ image_features, _ = self.image_embedding(
+ image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
+ )
+
+ def merge_vllm_multimodal_embeddings(
+ input_ids: torch.Tensor,
+ inputs_embeds: torch.Tensor,
+ multimodal_embeddings: torch.Tensor,
+ placeholder_token_id: int,
+ ) -> torch.Tensor:
+ mask = input_ids == placeholder_token_id
+ num_expected_tokens = mask.sum().item()
+
+ if multimodal_embeddings.shape[0] != num_expected_tokens:
+ raise ValueError(
+ f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
+ f"multimodal tokens to {num_expected_tokens} placeholders"
+ )
+
+ inputs_embeds[mask] = multimodal_embeddings
+ return inputs_embeds
+
+ inputs_embeds = merge_vllm_multimodal_embeddings(
+ input_ids, inputs_embeds, image_features, self.config.image_token_index
+ )
+
+ else:
+ inputs_embeds = self.text_embedding(input_ids=input_ids)
+
+ outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
  inputs_embeds=inputs_embeds,
  batch_idx=batch_idx,
  cache_position=cache_position,
- past_cached_length=past_cached_length,
  )

  return outputs
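The net effect of this rework is that past_cached_length bookkeeping is replaced by a per-sequence generate_idx tensor: prepare_inputs_for_generation seeds it with zeros, the prefill path fills each entry with that prompt's embedded length, and later decode steps reuse it as the next cache position. Below is a small standalone toy of that bookkeeping only (no RBLN runtime involved; batch size and prompt lengths are illustrative values, not from this diff).

    # Toy illustration of the generate_idx bookkeeping introduced above.
    import torch

    batch_size = 2
    prompt_lengths = [5, 8]  # arbitrary per-sequence prompt lengths

    # prepare_inputs_for_generation() seeds generate_idx with zeros on the first call.
    generate_idx = torch.zeros((batch_size, 1), dtype=torch.int32)
    assert not generate_idx.bool().all()  # at least one zero entry -> prefill phase

    # Prefill: each row records the length of its (embedded) prompt,
    # and cache positions run from 0 up to that length.
    for b_idx, seq_len in enumerate(prompt_lengths):
        generate_idx[b_idx] = seq_len
        cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
        print(f"prefill batch {b_idx}: cache positions {cache_position.tolist()}")

    # Decode: every entry is now non-zero, so forward() takes the language-model path
    # and generates one token per step starting at position generate_idx[b_idx].
    assert generate_idx.bool().all()
    print("decode starts at positions:", generate_idx.squeeze(1).tolist())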
@@ -73,3 +73,8 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
  if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
  return redirect(val)
  return val
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ kwargs.setdefault("trust_remote_code", True)
+ return super().from_pretrained(*args, **kwargs)
@@ -0,0 +1,24 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from .modeling_qwen2 import RBLNQwen2ForCausalLM
@@ -0,0 +1,67 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import inspect
+ import logging
+ from typing import TYPE_CHECKING, Any, Callable
+
+ from transformers import Qwen2ForCausalLM
+
+ from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .qwen2_architecture import QWEN2Wrapper
+
+
+ if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
+ from ....modeling_config import RBLNConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class RBLNQwen2ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+ """
+ The Llama Model transformer with a language modeling head (linear layer) on top.
+ This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+ A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
+ It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
+ - transferring the checkpoint weights of the original into an optimized RBLN graph,
+ - compiling the resulting graph using the RBLN compiler.
+ """
+
+ @classmethod
+ def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+ rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
+ return QWEN2Wrapper(model, rbln_max_seq_len).eval()
+
+ def __getattr__(self, __name: str) -> Any:
+ def redirect(func):
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+ val = getattr(Qwen2ForCausalLM, __name)
+
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+ return redirect(val)
+
+ return val
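The __getattr__ override above (shared with the Exaone and Midm classes) is what lets the RBLN wrapper fall back to the original Hugging Face implementation for anything it does not define itself: attribute lookups that miss are resolved on Qwen2ForCausalLM, and plain functions that expect self are re-bound to the RBLN instance. A standalone toy of that redirection pattern follows; the class names are stand-ins, only the mechanism mirrors the code above.

    # Stand-in classes; only the redirection mechanism mirrors the code above.
    import inspect
    from typing import Any, Callable

    class HFModel:  # plays the role of Qwen2ForCausalLM
        def can_generate(self) -> bool:
            return True

    class RBLNWrapperLike:  # plays the role of RBLNQwen2ForCausalLM
        def __getattr__(self, __name: str) -> Any:
            def redirect(func):
                return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

            val = getattr(HFModel, __name)
            if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
                return redirect(val)  # unbound function -> call it with this instance as self
            return val

    print(RBLNWrapperLike().can_generate())  # True, served by HFModel.can_generate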
@@ -0,0 +1,29 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+ class QWEN2Wrapper(DecoderOnlyWrapper):
+ pass
@@ -0,0 +1,24 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from .modeling_seq2seq import RBLNModelForSeq2SeqLM
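Given the relocation listed above (optimum/rbln/modeling_seq2seq.py moved to optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py), the re-export in this __init__ makes the class importable from the new subpackage; whether the old top-level path is still exposed is not visible in this excerpt, so the sketch below only uses the path backed by the diff.

    # Import path backed by the seq2seq __init__ above; the old location is not shown in this diff.
    from optimum.rbln.transformers.models.seq2seq import RBLNModelForSeq2SeqLM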