optimum-rbln 0.2.1a0__py3-none-any.whl → 0.2.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. optimum/rbln/__init__.py +3 -10
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +1 -10
  4. optimum/rbln/diffusers/modeling_diffusers.py +1 -10
  5. optimum/rbln/diffusers/models/__init__.py +1 -10
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
  7. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
  9. optimum/rbln/diffusers/models/controlnet.py +1 -10
  10. optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
  12. optimum/rbln/diffusers/models/unets/__init__.py +1 -10
  13. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
  14. optimum/rbln/diffusers/pipelines/__init__.py +1 -10
  15. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
  16. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
  20. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
  24. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
  32. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
  33. optimum/rbln/modeling.py +1 -10
  34. optimum/rbln/modeling_base.py +1 -10
  35. optimum/rbln/modeling_config.py +1 -10
  36. optimum/rbln/ops/__init__.py +1 -10
  37. optimum/rbln/ops/attn.py +9 -18
  38. optimum/rbln/ops/flash_attn.py +5 -14
  39. optimum/rbln/ops/kv_cache_update.py +1 -10
  40. optimum/rbln/transformers/__init__.py +5 -12
  41. optimum/rbln/transformers/modeling_alias.py +1 -14
  42. optimum/rbln/transformers/modeling_generic.py +40 -21
  43. optimum/rbln/transformers/modeling_rope_utils.py +28 -0
  44. optimum/rbln/transformers/models/__init__.py +3 -12
  45. optimum/rbln/transformers/models/auto/__init__.py +1 -10
  46. optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
  47. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
  48. optimum/rbln/transformers/models/bart/__init__.py +1 -10
  49. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
  50. optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
  51. optimum/rbln/transformers/models/bert/__init__.py +2 -11
  52. optimum/rbln/transformers/models/bert/modeling_bert.py +23 -13
  53. optimum/rbln/transformers/models/clip/__init__.py +1 -10
  54. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
  55. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
  56. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +54 -69
  57. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
  58. optimum/rbln/transformers/models/dpt/__init__.py +1 -10
  59. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
  60. optimum/rbln/transformers/models/exaone/__init__.py +1 -10
  61. optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
  62. optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
  63. optimum/rbln/transformers/models/gemma/__init__.py +1 -10
  64. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
  65. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
  66. optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
  67. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
  68. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
  69. optimum/rbln/transformers/models/llama/__init__.py +1 -10
  70. optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
  71. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
  72. optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
  74. optimum/rbln/transformers/models/midm/__init__.py +1 -10
  75. optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
  76. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
  77. optimum/rbln/transformers/models/mistral/__init__.py +1 -10
  78. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
  79. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
  80. optimum/rbln/transformers/models/phi/__init__.py +1 -10
  81. optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
  82. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
  83. optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
  84. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
  85. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
  86. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
  87. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
  88. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +16 -42
  89. optimum/rbln/transformers/models/t5/__init__.py +1 -10
  90. optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
  91. optimum/rbln/transformers/models/t5/t5_architecture.py +30 -16
  92. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
  93. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
  94. optimum/rbln/transformers/models/whisper/__init__.py +1 -10
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
  96. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
  97. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
  99. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
  100. optimum/rbln/transformers/utils/rbln_quantization.py +6 -10
  101. optimum/rbln/utils/__init__.py +1 -10
  102. optimum/rbln/utils/decorator_utils.py +1 -10
  103. optimum/rbln/utils/hub.py +1 -10
  104. optimum/rbln/utils/import_utils.py +1 -10
  105. optimum/rbln/utils/logging.py +1 -10
  106. optimum/rbln/utils/model_utils.py +1 -10
  107. optimum/rbln/utils/runtime_utils.py +1 -10
  108. optimum/rbln/utils/save_utils.py +2 -10
  109. optimum/rbln/utils/submodule.py +1 -10
  110. {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/METADATA +6 -4
  111. optimum_rbln-0.2.1a2.dist-info/RECORD +114 -0
  112. optimum_rbln-0.2.1a2.dist-info/licenses/LICENSE +201 -0
  113. optimum_rbln-0.2.1a0.dist-info/RECORD +0 -114
  114. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +0 -288
  115. {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .gemma_architecture import GemmaWrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_gpt2 import RBLNGPT2LMHeadModel
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import math
25
16
  from typing import TYPE_CHECKING, Tuple
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .gpt2_architecture import GPT2Wrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_llama import RBLNLlamaForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper
25
16
 
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .llama_architecture import LlamaWrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_llava_next import RBLNLlavaNextForConditionalGeneration
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  import logging
26
17
  from pathlib import Path
@@ -233,14 +224,14 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
233
224
 
234
225
  if is_prefill_phase:
235
226
  model_inputs["generate_idx"] = torch.zeros((batch_size, 1), dtype=torch.int32)
227
+ model_inputs.update(
228
+ {
229
+ "pixel_values": pixel_values,
230
+ "image_sizes": image_sizes,
231
+ }
232
+ )
236
233
 
237
- model_inputs.update(
238
- {
239
- "pixel_values": pixel_values,
240
- "image_sizes": image_sizes,
241
- "attention_mask": attention_mask,
242
- }
243
- )
234
+ model_inputs["attention_mask"] = attention_mask
244
235
  return model_inputs
245
236
 
246
237
  def _update_model_kwargs_for_generation(
@@ -266,11 +257,11 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
266
257
 
267
258
  def image_embedding(
268
259
  self,
269
- image_sizes: torch.LongTensor,
260
+ image_sizes: torch.Tensor,
270
261
  pixel_values: torch.FloatTensor,
271
262
  vision_feature_layer: int,
272
263
  vision_feature_select_strategy: str,
273
- ) -> torch.Tensor:
264
+ ):
274
265
  vision_feature_layer = (
275
266
  vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
276
267
  )
@@ -280,6 +271,23 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
280
271
  else self.config.vision_feature_select_strategy
281
272
  )
282
273
 
274
+ """
275
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
276
+
277
+ Args:
278
+ pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`)
279
+ The tensors corresponding to the input images.
280
+ image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
281
+ Actual image size of each images (H, W).
282
+ vision_feature_layer (`int`):
283
+ The index of the layer to select the vision feature.
284
+ vision_feature_select_strategy (`str`):
285
+ The feature selection strategy used to select the vision feature from the vision backbone.
286
+ Can be one of `"default"` or `"full"`
287
+ Returns:
288
+ image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches
289
+ and are of shape `(num_patches, image_length, embed_dim)`).
290
+ """
283
291
  # ! infer image_num_patches from image_sizes
284
292
  image_num_patches = [
285
293
  image_size_to_num_patches(
@@ -289,10 +297,8 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
289
297
  )
290
298
  for imsize in image_sizes
291
299
  ]
292
-
293
- # figure out if pixel_values is concatenated or stacked
294
300
  if pixel_values.dim() == 5:
295
- # stacking when input is (batch_size, num_patches, num_channels, height, width)
301
+ # stacked if input is (batch_size, num_patches, num_channels, height, width)
296
302
  _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
297
303
  pixel_values = torch.cat(_pixel_values_list, dim=0)
298
304
  elif pixel_values.dim() != 4:
@@ -301,12 +307,10 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
301
307
 
302
308
  image_features = self.vision_tower(pixel_values, output_hidden_states=True)
303
309
  selected_image_feature = image_features.hidden_states[vision_feature_layer]
304
-
305
310
  if vision_feature_select_strategy == "default":
306
311
  selected_image_feature = selected_image_feature[:, 1:]
307
312
  elif vision_feature_select_strategy == "full":
308
313
  selected_image_feature = selected_image_feature
309
-
310
314
  image_features = self.multi_modal_projector(selected_image_feature)
311
315
  image_features = torch.split(image_features, image_num_patches, dim=0)
312
316
 
@@ -314,6 +318,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
314
318
  image_features, feature_lens = self.pack_image_features(
315
319
  image_features,
316
320
  image_sizes,
321
+ vision_feature_select_strategy=vision_feature_select_strategy,
317
322
  image_newline=self.image_newline,
318
323
  )
319
324
 
@@ -330,78 +335,63 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
330
335
  vision_feature_select_strategy: Optional[str] = None,
331
336
  cache_position: torch.Tensor = None,
332
337
  generate_idx: Optional[torch.Tensor] = None,
338
+ batch_idx: Optional[int] = None,
333
339
  **kwargs,
334
340
  ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
341
+ vision_feature_layer = (
342
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
343
+ )
344
+ vision_feature_select_strategy = (
345
+ vision_feature_select_strategy
346
+ if vision_feature_select_strategy is not None
347
+ else self.config.vision_feature_select_strategy
348
+ )
349
+
335
350
  if inputs_embeds is not None:
336
351
  raise NotImplementedError("Specifying inputs_embeds is not supported.")
352
+ inputs_embeds = self.get_input_embeddings()(input_ids)
353
+
354
+ if pixel_values is not None and pixel_values.size(0) > 0:
355
+ image_features, _ = self.image_embedding(
356
+ pixel_values=pixel_values,
357
+ image_sizes=image_sizes,
358
+ vision_feature_layer=vision_feature_layer,
359
+ vision_feature_select_strategy=vision_feature_select_strategy,
360
+ )
361
+
362
+ n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
363
+ n_image_features = image_features.shape[0]
364
+ if n_image_tokens != n_image_features:
365
+ raise ValueError(
366
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
367
+ )
368
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
369
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
370
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
371
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
337
372
 
338
373
  is_prefill_phase = not generate_idx.bool().all()
339
374
 
340
375
  if is_prefill_phase:
341
- # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
342
- # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
343
- # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
344
- legacy_processing = (
345
- (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
346
- ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
347
-
348
- # Get the number of images in the prompt
349
- special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
350
- if legacy_processing:
351
- num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
352
- else:
353
- image_tokens_masks_diff = [
354
- torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
355
- ]
356
- num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]
357
-
358
- # Split images for each prompt
359
- if pixel_values is not None and pixel_values.size(0) > 0:
360
- pixel_values = pixel_values.split(num_special_image_tokens, dim=0)
361
- image_sizes = image_sizes.split(num_special_image_tokens, dim=0)
362
-
363
376
  logits = []
364
- for b_idx in range(input_ids.shape[0]):
365
- # Get text_embeds from input_id
366
- input_id = input_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
367
- inputs_embed = self.text_embedding(input_id)
368
-
369
- # If any images in the prompt, get image_embeds and merge with text
370
- if num_special_image_tokens[b_idx] > 0:
371
- image_features, feature_lens = self.image_embedding(
372
- image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
373
- )
374
- if legacy_processing:
375
- inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
376
- image_features,
377
- feature_lens,
378
- inputs_embed.to(image_features.dtype),
379
- input_id,
380
- torch.ones_like(input_id, dtype=torch.long),
381
- )
382
- else:
383
- special_image_mask = (
384
- (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
385
- )
386
- inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)
387
-
388
- # Update generate_idx according to inputs_embed
389
- generate_idx[b_idx] = inputs_embed.shape[1]
390
-
377
+ batch_size = input_ids.shape[0]
378
+ inputs_embeds = [inputs_embeds[i : i + 1, attention_mask[i].bool()] for i in range(batch_size)]
379
+ for batch_idx in range(batch_size):
380
+ generate_idx[batch_idx] = inputs_embeds[batch_idx].shape[-2]
391
381
  logit = self.language_model._forward_prefill(
392
- inputs_embeds=inputs_embed,
393
- batch_idx=b_idx,
394
- cache_position=torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0),
382
+ inputs_embeds=inputs_embeds[batch_idx],
383
+ batch_idx=batch_idx,
384
+ cache_position=torch.arange(
385
+ 0,
386
+ generate_idx[batch_idx].item(),
387
+ dtype=torch.int32,
388
+ ).unsqueeze(0),
395
389
  )
396
390
 
397
391
  logits.append(logit)
398
-
399
392
  logits = torch.cat(logits, dim=0)
400
393
  outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
401
-
402
394
  else:
403
- inputs_embeds = self.text_embedding(input_ids)
404
-
405
395
  outputs: RBLNDecoderOnlyOutput = self.language_model(
406
396
  inputs_embeds=inputs_embeds,
407
397
  cache_position=cache_position,
@@ -410,8 +400,8 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
410
400
 
411
401
  return outputs
412
402
 
413
- # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
414
- def pack_image_features(self, image_features, image_sizes, image_newline=None):
403
+ # Almost copied from : https://github.com/huggingface/transformers/blob/6b550462139655d488d4c663086a63e98713c6b9/src/transformers/models/llava_next/modeling_llava_next.py
404
+ def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
415
405
  """
416
406
  Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
417
407
 
@@ -420,6 +410,8 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
420
410
  List of image feature tensor, each contains all the visual feature of all patches.
421
411
  image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
422
412
  Actual image size of each images (H, W).
413
+ vision_feature_select_strategy (`str`)
414
+ The feature selection strategy used to select the vision feature from the vision backbone.
423
415
  image_newline (`torch.Tensor` of shape `(embed_dim)`)
424
416
  New line embedding vector.
425
417
  Returns:
@@ -434,9 +426,15 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
434
426
  base_image_feature = image_feature[0]
435
427
  image_feature = image_feature[1:]
436
428
  height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
437
- if height * width != base_image_feature.shape[0]:
429
+
430
+ if vision_feature_select_strategy == "default":
431
+ expected_num_patches = height * width
432
+ elif vision_feature_select_strategy == "full":
433
+ expected_num_patches = height * width + 1
434
+ if expected_num_patches != base_image_feature.shape[0]:
438
435
  raise ValueError("The number of patches is not consistent with the image size.")
439
- num_patch_width, num_patch_height = get_anyres_image_grid_shape(
436
+
437
+ num_patch_height, num_patch_width = get_anyres_image_grid_shape(
440
438
  image_sizes[image_idx],
441
439
  self.config.image_grid_pinpoints,
442
440
  self.config.vision_config.image_size,
@@ -449,7 +447,9 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
449
447
  image_feature = torch.cat(
450
448
  (
451
449
  image_feature,
452
- image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
450
+ image_newline[:, None, None]
451
+ .expand(*image_feature.shape[:-1], 1)
452
+ .to(image_feature.device, image_feature.dtype),
453
453
  ),
454
454
  dim=-1,
455
455
  )
@@ -498,7 +498,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
498
498
  return height // patch_size, width // patch_size
499
499
 
500
500
 
501
- # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
501
+ # Almost copied from : https://github.com/huggingface/transformers/blob/1feebb5b4150882deabddd190a541f336f3be817/src/transformers/models/llava_next/modeling_llava_next.py#L115C1-L152C1
502
502
  def unpad_image(tensor, original_size):
503
503
  """
504
504
  Unpads a PyTorch tensor of a padded and resized image.
@@ -512,6 +512,12 @@ def unpad_image(tensor, original_size):
512
512
  Returns:
513
513
  `torch.Tensor`: The unpadded image tensor.
514
514
  """
515
+ if not isinstance(original_size, (list, tuple)):
516
+ if not isinstance(original_size, (torch.Tensor, np.ndarray)):
517
+ raise TypeError(
518
+ f"image_size invalid type: {type(original_size)} not valid, should be either list, tuple, np.ndarray or tensor"
519
+ )
520
+ original_size = original_size.tolist()
515
521
  original_height, original_width = original_size
516
522
  current_height, current_width = tensor.shape[1:]
517
523
 
@@ -520,12 +526,12 @@ def unpad_image(tensor, original_size):
520
526
 
521
527
  if original_aspect_ratio > current_aspect_ratio:
522
528
  scale_factor = current_width / original_width
523
- new_height = int(original_height * scale_factor)
529
+ new_height = int(round(original_height * scale_factor, 7))
524
530
  padding = (current_height - new_height) // 2
525
531
  unpadded_tensor = tensor[:, padding : current_height - padding, :]
526
532
  else:
527
533
  scale_factor = current_height / original_height
528
- new_width = int(original_width * scale_factor)
534
+ new_width = int(round(original_width * scale_factor, 7))
529
535
  padding = (current_width - new_width) // 2
530
536
  unpadded_tensor = tensor[:, :, padding : current_width - padding]
531
537
 
@@ -577,7 +583,7 @@ def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
577
583
  Calculate the number of patches after the preprocessing for images of any resolution.
578
584
 
579
585
  Args:
580
- image_size (`Union[torch.LongTensor, np.ndarray, Tuple[int, int]):
586
+ image_size (`torch.LongTensor` or `np.ndarray` or `Tuple[int, int]`):
581
587
  The size of the input image in the format (height, width). ?
582
588
  grid_pinpoints (`List`):
583
589
  A list containing possible resolutions. Each item in the list should be a tuple or list
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import os
25
16
  from os import environ
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import math
25
16
  from typing import TYPE_CHECKING, Tuple
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from transformers import AutoModelForCausalLM
25
16
 
26
17
  from ....utils import logging
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_mistral import RBLNMistralForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
25
16
 
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .mistral_architecture import MistralForCausalLMWrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_phi import RBLNPhiForCausalLM