optimum-rbln 0.9.3rc0 (py3-none-any.whl) → 0.9.5a4 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -26,15 +26,16 @@ from transformers.modeling_utils import no_init_weights
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import is_compiler_supports_buffer_resize
 from ...modeling_attention_utils import (
     RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
     validate_sliding_window,
 )
-from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
 from ...utils.rbln_quantization import get_quantized_model
-from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_decoderonly import KVCacheMeta, RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import DecoderOnlyWrapper
 from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
 from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
@@ -88,8 +89,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     def setup_runtime(self):
         # Initialize resources to be used across Runtime instances (prefill and decode phases)
         page_table_manager = RBLNPageTableManager(self.rbln_config)
-        dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
-        out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+        if self.rbln_config.use_position_ids:
+            dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=self.dtype)
+        else:
+            dec_attn_mask = torch.zeros(
+                self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype
+            )

         common_kwargs = {
             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
@@ -97,12 +102,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "dec_attn_mask": dec_attn_mask,
             "page_table_manager": page_table_manager,
             "rbln_config": self.rbln_config,
+            "config": self.config,
         }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-            out_buffers=out_buffers,
+            logits_last_dim=self.logits_last_dim,
             **common_kwargs,
         )
         if self.can_generate():
@@ -119,12 +125,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         self.decoder = self.decoders[self.rbln_config.batch_size]

     @property
-    def prefill_output_size(self):
-        return (
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.hidden_size,
-        )
+    def logits_last_dim(self):
+        return self.config.hidden_size

     @classmethod
     def get_quantized_model(
@@ -216,7 +218,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return self.rbln_config.kvcache_num_blocks

     @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
         return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()

     @classmethod
@@ -229,7 +231,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         quantization=None,
         phase: str = "prefill",
-    ):
+    ) -> rebel.RBLNCompiledModel:
         try:
             wrapped_model.phase = phase
             if quantization:
@@ -251,28 +253,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             quantization.maybe_reset_quantization_env()

     @classmethod
-    def _get_compile_context(
-        cls,
-        compile_config: RBLNCompileConfig,
-        example_inputs: List[torch.Tensor],
-    ):
+    def _get_compile_context(cls, compile_config: RBLNCompileConfig, example_inputs: List[torch.Tensor]):
         context = CompileContext(use_weight_sharing=True)

         # Mark static tensors (self kv states)
         static_tensors = {}
-        idx = 0
         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
             if "past_key_values" in name:
                 static_tensors[name] = tensor
-                context.mark_static_address(tensor, f"kv_cache_{idx}")
-                idx += 1
+                context.mark_static_address(tensor, name)

         return context, static_tensors

     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
         prefill_compile_config = rbln_config.compile_cfgs[0]

         # Here we use meta tensor, for the memory efficiency.
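
Note: KV-cache buffers are now registered under their own input names instead of a separate running kv_cache_{idx} counter, so the static address label and the input_info entry always agree. A sketch of the selection loop with mark_static_address mocked out, since CompileContext belongs to the RBLN toolchain:

    import torch

    # Hypothetical input_info entries mirroring the compile config format.
    input_info = [
        ("input_ids", [1, 128], "int64"),
        ("past_key_values_0", [16, 8, 128, 64], "float32"),
        ("past_key_values_1", [16, 8, 128, 64], "float32"),
    ]
    example_inputs = [torch.zeros(shape) for _, shape, _ in input_info]

    static_tensors = {}
    for (name, _, _), tensor in zip(input_info, example_inputs):
        if "past_key_values" in name:
            static_tensors[name] = tensor
            # Previously: context.mark_static_address(tensor, f"kv_cache_{idx}")
            print(f"mark_static_address(..., {name!r})")  # label now equals the input name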
@@ -280,7 +276,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)

-        compiled_models = {}
+        compiled_models: dict[str, rebel.RBLNCompiledModel] = {}
         compiled_models["prefill"] = cls._compile_model(
             wrapped_model,
             prefill_compile_config,
@@ -306,14 +302,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             )
             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder

-        # check if the memory is enough to have additional blocks
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        if rbln_config.kvcache_num_blocks < required_num_blocks:
-            cls.maybe_suggest_kvcache_num_blocks(
-                compiled_models=compiled_models,
-                model_config=model.config,
-                rbln_config=rbln_config,
-            )
+        if rbln_config.is_auto_num_blocks:
+            if not is_compiler_supports_buffer_resize():
+                raise RuntimeError("`kvcache_num_blocks` must be set.")
+            cls.set_kvcache_num_blocks_after_compilation(compiled_models, rbln_config)

         return compiled_models

@@ -329,8 +321,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return model

     @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return use_local_attention
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None):
+        return is_prefill and (use_local_attention or logits_to_keep == 1)

     @classmethod
     def get_input_info(
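
Note: use_query_position is now a prefill-only predicate that also consults logits_to_keep. A small truth-table sketch of the new behavior:

    def use_query_position(use_local_attention, is_prefill=True, logits_to_keep=None):
        return is_prefill and (use_local_attention or logits_to_keep == 1)

    # Decode graphs never take a query_position input anymore:
    assert use_query_position(True, is_prefill=False) is False
    # Prefill does, when local attention is used or exactly one logit is kept:
    assert use_query_position(True, is_prefill=True) is True
    assert use_query_position(False, is_prefill=True, logits_to_keep=1) is True
    assert use_query_position(False, is_prefill=True, logits_to_keep=0) is False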
@@ -340,16 +332,16 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+        hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
         is_prefill = query_length > 1

         input_info = []
         if rbln_config.use_inputs_embeds:
-            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.dtype))
         else:
             input_info.append(("input_ids", [batch_size, query_length], "int64"))

@@ -363,15 +355,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_local_attention:
             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill, rbln_config.logits_to_keep):
             input_info.append(("query_position", [], "int16"))

         if rbln_config.use_attention_mask:
             if rbln_config.use_position_ids:
-                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.dtype))
             else:
                 input_info.append(
-                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.dtype)
                 )

         if rbln_config.use_position_ids:
@@ -380,29 +372,36 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_lora:
             input_info.append(("lora_int_ids", [batch_size], "int32"))

-        kvcache_dtype = rbln_config.torch_dtype
-        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
-            kvcache_dtype = "float8_e4m3fn"
+        if len(rbln_config.kvcache_metas) > 0:
+            # Meta is already set, use it
+            input_info.extend(
+                [
+                    (kvcache_meta.name, kvcache_meta.compile_shape, kvcache_meta.dtype)
+                    for kvcache_meta in rbln_config.kvcache_metas
+                ]
+            )

-        global_kvcache_shape = [
-            rbln_config.kvcache_num_blocks,
-            num_key_value_heads,
-            rbln_config.kvcache_block_size,
-            head_dim,
-        ]
-        local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape
-                    if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
-                    else global_kvcache_shape,
-                    kvcache_dtype,
+        else:
+            kvcache_dtype = rbln_config.dtype
+            if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+                kvcache_dtype = "float8_e4m3fn"
+
+            kvcache_metas = []
+            for i in range(num_hidden_layers * 2):
+                layer_idx = i // 2
+                name = f"past_key_values_{i}"
+                kvcache_meta = KVCacheMeta.make(
+                    name,
+                    layer_idx,
+                    num_key_value_heads,
+                    head_dim,
+                    RBLNCompileConfig.normalize_dtype(kvcache_dtype),
+                    rbln_config,
                 )
-                for i in range(num_hidden_layers * 2)
-            ]
-        )
+                kvcache_metas.append(kvcache_meta)
+                input_info.append((name, kvcache_meta.compile_shape, kvcache_meta.dtype))
+
+            rbln_config.kvcache_metas.extend(kvcache_metas)

         return input_info

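Note: the flat past_key_values_{i} naming is preserved, with index i mapping to layer i // 2 (by the usual convention, even indices carry keys and odd indices carry values). A tiny sketch of the mapping:

    num_hidden_layers = 2  # illustrative

    for i in range(num_hidden_layers * 2):
        layer_idx = i // 2
        kind = "key" if i % 2 == 0 else "value"  # conventional ordering, not spelled out in this diff
        print(f"past_key_values_{i} -> layer {layer_idx} ({kind})")
    # past_key_values_0 -> layer 0 (key)
    # past_key_values_1 -> layer 0 (value)
    # past_key_values_2 -> layer 1 (key)
    # past_key_values_3 -> layer 1 (value)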
@@ -439,10 +438,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         # Returns:
         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

-        raise NotImplementedError(
-            "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
-            "See method docstring for required configuration details."
+        rbln_config.sliding_window = model_config.sliding_window
+        sliding_window_layers = []
+
+        for i in range(model_config.num_hidden_layers):
+            if hasattr(model_config, "layer_types"):
+                if model_config.layer_types[i] == "sliding_attention":
+                    sliding_window_layers.append(i)
+            else:
+                sliding_window_layers.append(i)
+
+        rbln_config.sliding_window_layers = sliding_window_layers
+
+        rbln_config.cache_impl = (
+            "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
         )
+        return rbln_config

     @classmethod
     def _update_attention_config(
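
Note: applied to a config with mixed layer_types (as in Gemma-2-style models), this default implementation collects only the sliding layers and marks the cache hybrid; configs without layer_types treat every layer as sliding. A standalone sketch of that selection:

    # Hypothetical layer layout alternating sliding and full attention.
    layer_types = ["sliding_attention", "full_attention"] * 2
    num_hidden_layers = len(layer_types)

    sliding_window_layers = [i for i, t in enumerate(layer_types) if t == "sliding_attention"]
    cache_impl = "sliding_window" if len(sliding_window_layers) == num_hidden_layers else "hybrid"

    print(sliding_window_layers, cache_impl)  # [0, 2] hybrid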
@@ -462,58 +473,40 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             max_seq_len=rbln_config.max_seq_len,
         )

-        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-
-        # Update kvcache_num_blocks based on the attention implementation.
+        # Validate kvcache_num_blocks based on the number of full blocks required.
+        # Eager mode restriction:
+        #   - num_blocks must be at least equal to the batch size
+        # Flash attention restriction:
+        #   - num_blocks must be at least equal to (max_seq_len // kvcache_block_size) + 1
+        #   - num_blocks must be no greater than the number of full blocks.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.get_maximum_num_blocks(
-                config=model_config,
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
-            )
+            if rbln_config.is_auto_num_blocks:
+                # Do nothing
+                pass

-            if rbln_config.kvcache_num_blocks is None:
-                if estimated_max_num_blocks < num_full_blocks:
-                    # lower bound of the number of blocks for flash attention.
-                    min_blocks_for_flash = min(
-                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+            else:
+                if rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
+                    logger.warning(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                        f" than the required number of blocks ({rbln_config.num_full_blocks})."
+                        "This can cause a failure during model compilation."
+                    )
+                elif rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
+                    raise ValueError(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is less"
+                        f" than the minimum number of blocks ({rbln_config.num_min_blocks})."
                     )
-                    if min_blocks_for_flash > estimated_max_num_blocks:
-                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
-                        # Even if it's larger than the estimated maximum number of blocks.
-                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
-                    else:
-                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
-
-                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
-                        raise RuntimeError(
-                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
-                            "Ensure the number of blocks is at least equal to the batch size."
-                        )
-                else:
-                    rbln_config.kvcache_num_blocks = num_full_blocks
-            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
-                logger.warning(
-                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
-                    "This can cause a failure during model compilation."
-                )
         else:
-            if rbln_config.kvcache_num_blocks is None:
-                rbln_config.kvcache_num_blocks = num_full_blocks
-            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+            if rbln_config.is_auto_num_blocks:
+                # Eager attention should use fixed number of blocks.
+                rbln_config.kvcache_num_blocks = rbln_config.num_full_blocks
+            elif rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
                 logger.warning(
                     f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the required number of blocks ({num_full_blocks})."
+                    f" than the required number of blocks ({rbln_config.num_full_blocks})."
                     "This can cause a failure during model compilation."
                 )

-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
         return rbln_config

     @classmethod
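
Note: the bounds named in the comments above can be checked by hand. Assuming num_full_blocks and num_min_blocks are computed as those comments state (the properties themselves live on the config class and are not part of this diff):

    max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 2

    # Upper bound: every sequence fully resident in its own blocks.
    num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 16
    # Flash-attention lower bound: one full sequence plus one spare block.
    num_min_blocks = max_seq_len // kvcache_block_size + 1  # 9

    for kvcache_num_blocks in (8, 9, 16, 17):
        too_small = kvcache_num_blocks < num_min_blocks
        too_large = kvcache_num_blocks > num_full_blocks
        print(kvcache_num_blocks, "raises" if too_small else "warns" if too_large else "ok")
    # 8 raises, 9 ok, 16 ok, 17 warns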
@@ -531,8 +524,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")

-        if getattr(model_config, "sliding_window", None) is not None and getattr(
-            model_config, "use_sliding_window", True
+        layer_types = getattr(model_config, "layer_types", None)
+        all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)
+
+        if (
+            getattr(model_config, "sliding_window", None) is not None
+            and getattr(model_config, "use_sliding_window", True)
+            and not all_full_attention
         ):
             rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
             if rbln_config.sliding_window is not None:
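
Note: with this guard, a config that declares a sliding_window but whose layer_types are all full attention no longer takes the sliding-window path. A quick sketch:

    layer_types = ["full_attention"] * 4
    all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)

    sliding_window = 4096  # present in the config, but effectively unused here
    use_sliding_window = True

    takes_sliding_path = sliding_window is not None and use_sliding_window and not all_full_attention
    print(takes_sliding_path)  # False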
@@ -608,34 +606,66 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> BaseModelOutputWithPast:
+        """
+        Args:
+            input_ids (torch.LongTensor, optional): The input IDs to the model.
+            inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            kwargs (dict[str, Any], optional): Additional keyword arguments.
+
+        Returns:
+            Dataclass containing the last hidden states of the model.
+        """
         inputs = inputs_embeds if inputs_embeds is not None else input_ids
         batch_size = inputs.shape[0]
+        position_embed = kwargs.get("position_embed", None)

         if batch_size != self.rbln_config.batch_size:
             raise ValueError(
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)

         all_last_hidden_states = []
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    self.rbln_config.batch_size,
+                    inputs.shape[1],
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
         for b_idx in range(self.rbln_config.batch_size):
             query_length = (
                 attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
             )
             cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
-            last_hidden_states = self.prefill_decoder(
-                inputs[b_idx : b_idx + 1],
+            outputs = self.prefill_decoder(
+                input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                 position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
-            ).logits
-            all_last_hidden_states.append(last_hidden_states)
+            )
+            all_last_hidden_states.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])

         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)


 class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
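
Note: a usage sketch for the new output_hidden_states flag on the base model's forward. The import path and model directory are assumptions for illustration; the tuple length follows the buffers allocated above (one tensor per layer plus the embedding output):

    import torch
    from optimum.rbln import RBLNAutoModel  # assumed export

    model = RBLNAutoModel.from_pretrained("compiled/decoder-only-model")  # placeholder path
    input_ids = torch.randint(0, model.config.vocab_size, (model.rbln_config.batch_size, 32))

    out = model(input_ids=input_ids, output_hidden_states=True)
    assert len(out.hidden_states) == model.config.num_hidden_layers + 1
    print(out.last_hidden_state.shape)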
@@ -648,6 +678,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
+
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures and causal language modeling tasks.

@@ -661,16 +692,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForCausalLM

     @property
-    def prefill_output_size(self):
-        return (
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.vocab_size,
-        )
-
-    @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return is_prefill
+    def logits_last_dim(self):
+        return self.config.vocab_size

     def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
         if isinstance(lora_int_ids, int):
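
Note: the shared logits_last_dim property replaces the per-class prefill_output_size tuple, presumably so the runtime can size its own output buffer; subclasses only report the final dimension, hidden_size for the base model and vocab_size for the causal-LM head. A reduced sketch of the two overrides:

    class Base:
        hidden_size, vocab_size = 4096, 32000  # illustrative

        @property
        def logits_last_dim(self):
            return self.hidden_size  # base model emits hidden states

    class CausalLM(Base):
        @property
        def logits_last_dim(self):
            return self.vocab_size  # LM head emits vocabulary logits

    print(Base().logits_last_dim, CausalLM().logits_last_dim)  # 4096 32000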
@@ -731,6 +754,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         token_type_ids: Optional[torch.Tensor] = None,
         lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -754,24 +778,48 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         )
         padded_cache_lengths = torch.zeros_like(generate_idx)

+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
         # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
+            input_len = inputs.shape[1]
+            if batch_size > self.rbln_config.batch_size:
+                raise ValueError(
+                    f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
+                )
+            if input_len > self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
+                )
+
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.dtype)
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.output_hidden_states
+                else None
+            )
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-                output = self.prefill_decoder(
+                outputs = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                    position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
-                padded_cache_lengths[b_idx] += output.padded_cache_lengths
-                logits.append(output.logits)
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
@@ -783,17 +831,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
                     f"Available batch sizes are: {list(self.decoders.keys())}. "
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
-            logits = self.decoders[batch_size](
+            if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Cache position exceeds the maximum sequence length.\n"
+                    f"  - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                    f"  - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                    f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                    f"or `max_length` in the generation config."
+                )
+
+            outputs = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
                 lora_int_ids=lora_int_ids,
-            ).logits
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states

         if not return_dict:
-            return logits, generate_idx, padded_cache_lengths
+            return logits, generate_idx, padded_cache_lengths, all_hidden_states
         else:
             return RBLNDecoderOnlyOutput(
-                logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
+                logits=logits,
+                generate_idx=generate_idx,
+                padded_cache_lengths=padded_cache_lengths,
+                hidden_states=all_hidden_states,
             )
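
Note: the decode-phase bound can be reproduced in isolation; cache positions must stay strictly below the compiled max_seq_len, so over-long generations now fail with an actionable message instead of a lower-level fault. A sketch with illustrative numbers (running it raises the ValueError on purpose):

    import torch

    max_seq_len = 128
    cache_position = torch.tensor([[127], [128]])  # the second sequence is out of range

    if max(cache_position.reshape(-1)) >= max_seq_len:
        raise ValueError(
            f"Cache position exceeds the maximum sequence length.\n"
            f"  - Current max cache position: {int(torch.max(cache_position).item())}\n"
            f"  - Allowed max_seq_len: {max_seq_len}\n"
            f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
            f"or `max_length` in the generation config."
        )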
optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py

@@ -13,6 +13,11 @@
 # limitations under the License.


+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation


@@ -23,3 +28,15 @@ class RBLNDepthAnythingForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for Depth Anything V2
     models on RBLN devices, providing the most capable monocular depth estimation (MDE) model.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DepthAnythingForDepthEstimation model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
optimum/rbln/transformers/models/distilbert/modeling_distilbert.py

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
+
 from ...modeling_generic import RBLNModelForQuestionAnswering


@@ -25,3 +30,22 @@ class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """

     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+        """
+        Forward pass for the RBLN-optimized DistilBERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
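
Note: a usage sketch for the now-documented question-answering forward; the checkpoint names are placeholders and the top-level export is assumed:

    import torch
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNDistilBertForQuestionAnswering  # assumed export

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # placeholder
    model = RBLNDistilBertForQuestionAnswering.from_pretrained("compiled/distilbert-qa")  # placeholder path

    inputs = tokenizer("Who wrote it?", "It was written by Ada.", return_tensors="pt")
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    start = int(torch.argmax(outputs.start_logits))
    end = int(torch.argmax(outputs.end_logits))
    print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))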
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -13,6 +13,11 @@
 # limitations under the License.


+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation


@@ -23,3 +28,15 @@ class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for DPT (Dense Prediction Transformer)
     models on RBLN devices, supporting monocular depth estimation from single images.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DPT model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
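
Note: and a matching usage sketch for the depth-estimation forward; identifiers are placeholders:

    import torch
    from optimum.rbln import RBLNDPTForDepthEstimation  # assumed export

    model = RBLNDPTForDepthEstimation.from_pretrained("compiled/dpt-depth")  # placeholder path
    pixel_values = torch.rand(1, 3, 384, 384)  # one RGB image at the compiled input size

    outputs = model(pixel_values)
    print(outputs.predicted_depth.shape)  # per-pixel depth map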