optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -21,7 +21,6 @@ from transformers import PretrainedConfig, PreTrainedModel
 
  from ....utils import logging
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from ...utils.rbln_quantization import RBLNQuantizationConfig
  from .configuration_lora import RBLNLoRAConfig
  from .lora_architecture import LoRALinear
 
@@ -76,8 +75,8 @@ class DecoderOnlyWrapper(nn.Module):
  f" or equal to max_seq_len({rbln_config.max_seq_len})!"
  )
 
- self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
- self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
+ self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len, use_rotary_emb)
+ self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
  self._phase = "prefill"
 
  def get_rotary_emb(self, max_seq_len):
@@ -104,7 +103,7 @@ class DecoderOnlyWrapper(nn.Module):
  def get_rbln_causal_lm_class(self):
  return DecoderOnlyForCausalLM
 
- def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+ def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int, use_rotary_emb: bool):
  new_layers = []
  for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
  is_sliding = layer_idx in self.rbln_config.sliding_window_layers
@@ -119,6 +118,7 @@ class DecoderOnlyWrapper(nn.Module):
  new_layers,
  self.rbln_config,
  use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+ use_rotary_emb=use_rotary_emb,
  )
 
  if self.is_causal_lm:
@@ -145,8 +145,11 @@ class DecoderOnlyWrapper(nn.Module):
  local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
  query_position = (
  args.pop(0)
- # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
- if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+ # query_position usage: prefill & (logits_to_keep == 1 or use_local_attention)
+ if (
+ "prefill" in self.phase
+ and (self.rbln_config.logits_to_keep == 1 or self.rbln_config.use_local_attention)
+ )
  else None
  )
  attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
@@ -203,7 +206,7 @@ class DecoderOnlyWrapper(nn.Module):
  rotary_emb,
  ) = self.prepare_forward_args(*args)
 
- logit = self.model(
+ logits, all_hidden_states = self.model(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -215,9 +218,13 @@ class DecoderOnlyWrapper(nn.Module):
  global_block_tables=global_block_tables,
  local_block_tables=local_block_tables,
  lora_int_id=lora_int_id,
+ output_hidden_states=self.rbln_config.output_hidden_states,
  )
 
- return logit
+ if self.rbln_config.output_hidden_states:
+ return logits, all_hidden_states
+ else:
+ return logits
 
 
  class DecoderOnlyForCausalLM(nn.Module):
@@ -237,7 +244,6 @@ class DecoderOnlyForCausalLM(nn.Module):
 
  Attributes:
  config: Configuration from the original causal language model
- _original_mod: Reference to the original model for components like lm_head
  model: RBLN-optimized decoder model instance
  _phase: Current processing phase ("prefill" or "decode")
  """
@@ -245,10 +251,9 @@ class DecoderOnlyForCausalLM(nn.Module):
  def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
  super().__init__()
  self.config = causal_lm.config
- self._original_mod = causal_lm
  self.model = model
  self._phase = "prefill"
- self.lm_head = self._original_mod.lm_head
+ self.lm_head = causal_lm.lm_head
 
  @property
  def phase(self):
@@ -272,9 +277,10 @@ class DecoderOnlyForCausalLM(nn.Module):
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
  lora_int_id: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
  ):
  # outputs
- hidden_states = self.model(
+ hidden_states, all_hidden_states = self.model(
  input_ids=input_ids,
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -286,9 +292,10 @@ class DecoderOnlyForCausalLM(nn.Module):
  global_block_tables=global_block_tables,
  local_block_tables=local_block_tables,
  lora_int_id=lora_int_id,
+ output_hidden_states=output_hidden_states,
  )
 
- if "prefill" in self.phase:
+ if "prefill" in self.phase and query_position is not None:
  hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
  logits = self.lm_head(hidden_states)
@@ -299,7 +306,7 @@ class DecoderOnlyForCausalLM(nn.Module):
  logits = torch.tanh(logits)
  logits = logits * self.config.final_logit_softcapping
 
- return logits
+ return logits, all_hidden_states
 
 
  class DecoderOnlyModel(nn.Module):
@@ -312,20 +319,35 @@ class DecoderOnlyModel(nn.Module):
  use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
  Attributes:
- _original_mod: Reference to original Huggingface model
  layers: ModuleList of RBLN-optimized transformer layers
  _phase: Current processing phase ("prefill" or "decode")
  """
 
+ _EMBEDDING_ATTRS = ["embed_tokens", "wte"]
+ _POSITION_ATTRS = ["embed_positions", "wpe"]
+ _LAYERNORM_ATTRS = ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"]
+ _PRE_FF_LAYERNORM_ATTRS = None
+ _POST_FF_LAYERNORM_ATTRS = None
+
  def __init__(
  self,
  model,
  layers: List["DecoderOnlyLayer"],
  rbln_config: "RBLNDecoderOnlyModelConfig",
  use_learned_pos_emb=None,
+ use_rotary_emb=True,
  ):
  super().__init__()
- self._original_mod = model
+ self.config = model.config
+ # Keep commonly-used original submodules registered on this wrapper so their weights
+ # are preserved in state_dict even if the original model object is not kept.
+ # Different HF model families use different attribute names; we register what we can
+ # and allow subclasses to override getters when needed.
+ self.embed_tokens = _get_attr_from_candidates(model, self._EMBEDDING_ATTRS)
+ # hasattr(model, "rotary_emb") is workaround for Qwen2VL
+ if not (use_rotary_emb or hasattr(model, "rotary_emb")):
+ self.embed_positions = _get_attr_from_candidates(model, self._POSITION_ATTRS)
+ self.norm = _get_attr_from_candidates(model, self._LAYERNORM_ATTRS)
  self.layers = nn.ModuleList(layers)
  self.rbln_config = rbln_config
  self._phase = "prefill"
@@ -364,26 +386,28 @@ class DecoderOnlyModel(nn.Module):
  cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
  return cache_pos_for_partitions
 
- def get_local_cache_positions(self, position_ids, query_position):
- max_cache_len = self._original_mod.config.sliding_window
+ def get_swa_custom_op_args(self, position_ids, query_position):
+ max_cache_len = self.config.sliding_window
  valid_input_len = 1 if query_position is None else query_position + 1
- cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1] # past seen tokens
+ cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1] # past seen tokens
  cache_offset = (
  torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
  ) # cache offset for next steps
 
- return cache_seq_len, cache_offset
+ # Causal mask for sliding window attention
+ attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
+ attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]
+
+ return cache_seq_len, cache_offset, attn_mask
 
  def get_last_layernorm(self) -> nn.LayerNorm:
- return self._original_mod.norm
+ return self.norm
 
  def get_embedding(self) -> nn.Embedding:
- return self._original_mod.embed_tokens
+ return self.embed_tokens
 
  def get_pos_embedding(self) -> nn.Embedding:
- raise NotImplementedError(
- "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
- )
+ return self.embed_positions
 
  def forward(
  self,
@@ -398,6 +422,7 @@ class DecoderOnlyModel(nn.Module):
  global_block_tables: Optional[torch.Tensor] = None,
  local_block_tables: Optional[torch.Tensor] = None,
  lora_int_id: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
  ):
  # retrieve input_ids and inputs_embeds
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -458,13 +483,19 @@ class DecoderOnlyModel(nn.Module):
 
  # Get local cache positions for sliding window layers
  if len(self.sliding_window_layers) > 0:
- sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+ cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+ sliding_cache_pos = (cache_seq_len, cache_offset)
 
+ all_hidden_states = () if output_hidden_states else None
  for layer_idx, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
  is_sliding = True if layer_idx in self.sliding_window_layers else False
+ is_sliding_decode = is_sliding and self.phase == "decode"
  hidden_states = layer(
  hidden_states=hidden_states,
- attention_mask=attention_mask,
+ attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
  seq_positions=sliding_cache_pos if is_sliding else seq_positions,
  past_key_values=past_key_values,
  cos=cos,
@@ -474,7 +505,10 @@ class DecoderOnlyModel(nn.Module):
  )
 
  hidden_states = self.get_last_layernorm()(hidden_states)
- return hidden_states
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return hidden_states, all_hidden_states
 
 
  class DecoderOnlyLayer(nn.Module):
@@ -497,14 +531,23 @@ class DecoderOnlyLayer(nn.Module):
  self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
 
  Attributes:
- _original_mod: Reference to original layer for accessing components
  self_attn: Modified attention mechanism mapped to RBLN ops at compile time
  phase: Current operation phase ("prefill" or "decode")
  """
 
+ _PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]
+ _POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
+ _PRE_FF_LAYERNORM_ATTRS = None
+ _POST_FF_LAYERNORM_ATTRS = None
+
  def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
  super().__init__()
- self._original_mod = layer
+
+ self.pre_attention_layernorm = _get_attr_from_candidates(layer, self._PRE_ATTN_LAYERNORM)
+ self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
+ self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
+ self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
+ self.mlp = layer.mlp
  self.self_attn = self_attn
  self._phase = "prefill"
  self.lora_config = lora_config
@@ -534,13 +577,19 @@ class DecoderOnlyLayer(nn.Module):
  self.self_attn.phase = phase
 
  def get_pre_attention_layernorm(self) -> nn.LayerNorm:
- return self._original_mod.input_layernorm
+ return self.pre_attention_layernorm
 
  def get_post_attention_layernorm(self) -> nn.LayerNorm:
- return self._original_mod.post_attention_layernorm
+ return self.post_attention_layernorm
+
+ def get_pre_feedforward_layernorm(self) -> nn.LayerNorm:
+ return self.pre_feedforward_layernorm
+
+ def get_post_feedforward_layernorm(self) -> nn.LayerNorm:
+ return self.post_feedforward_layernorm
 
  def get_mlp(self) -> nn.Module:
- return self._original_mod.mlp
+ return self.mlp
 
  def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
  mlp = self.get_mlp()
@@ -606,6 +655,8 @@ class DecoderOnlyAttention(nn.Module):
  is_sliding: Whether this is sliding window attention
  """
 
+ _O_PROJ_ATTRS = ["o_proj", "out_proj", "dense"]
+
  def __init__(
  self,
  self_attn,
@@ -613,39 +664,37 @@ class DecoderOnlyAttention(nn.Module):
  is_sliding=False,
  ):
  super().__init__()
- self._original_mod = self_attn
+ self.config = getattr(self_attn, "config", None)
  self.rbln_config = rbln_config
  self.layer_idx = self_attn.layer_idx
- self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
- self._original_mod.config, "num_attention_heads"
- )
- self.head_dim = self._original_mod.head_dim
+ self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+ self.head_dim = self_attn.head_dim
  self._phase = "prefill"
- self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
- self.quantization = rbln_config.quantization
+ self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale(self_attn)))
 
- if hasattr(self._original_mod, "num_key_value_heads"):
- self.num_key_value_heads = self._original_mod.num_key_value_heads
- elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
- self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+ if hasattr(self_attn, "num_key_value_heads"):
+ self.num_key_value_heads = self_attn.num_key_value_heads
+ elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+ self.num_key_value_heads = self_attn.config.num_key_value_heads
  else:
  self.num_key_value_heads = self.num_heads
 
- self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
- self.use_position_ids = rbln_config.use_position_ids
  self.is_sliding = is_sliding
  self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
  self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
  self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
  self.lora_config = rbln_config.lora_config
 
+ if hasattr(self_attn, "sinks"):
+ self.sinks = self_attn.sinks.data[:, None]
+
  setattr(self, self.get_attention_name(), self.create_attention_op())
- self.__post_init__()
+ self.__post_init__(self_attn)
 
  def _init_lora_weights(self):
  """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
  for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
- original_linear = getattr(self._original_mod, proj_name)
+ original_linear = getattr(self, proj_name)
  lora_linear = LoRALinear(
  original_linear=original_linear,
  lora_config=self.lora_config,
@@ -680,8 +729,7 @@ class DecoderOnlyAttention(nn.Module):
  self.num_heads,
  self.head_dim,
  self.num_key_value_heads,
- self.use_attention_mask,
- self.use_position_ids,
+ rbln_config=self.rbln_config,
  )
  elif self.attn_impl == "flash_attn":
  return FlashAttentionOp(
@@ -689,32 +737,29 @@ class DecoderOnlyAttention(nn.Module):
  self.head_dim,
  self.num_key_value_heads,
  self.kvcache_partition_len,
- self.use_attention_mask,
- self.use_position_ids,
- self.quantization,
+ rbln_config=self.rbln_config,
+ is_sliding=False,
  )
  elif self.attn_impl == "eager":
  return AttentionOp(
  self.num_heads,
  self.head_dim,
  self.num_key_value_heads,
- self.use_attention_mask,
- self.use_position_ids,
- self.quantization,
+ rbln_config=self.rbln_config,
+ is_sliding=False,
  )
  else:
  raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
- def __post_init__(self):
+ def __post_init__(self, self_attn=None):
+ self.q_proj = self_attn.q_proj
+ self.k_proj = self_attn.k_proj
+ self.v_proj = self_attn.v_proj
+ self.o_proj = _get_attr_from_candidates(self_attn, self._O_PROJ_ATTRS)
+
  # Initialize LoRA weights if configured, which will replace linear layers
  if self.lora_config:
  self._init_lora_weights()
- else:
- # Use original linear layers if no LoRA
- self.q_proj = self._original_mod.q_proj
- self.k_proj = self._original_mod.k_proj
- self.v_proj = self._original_mod.v_proj
- self.o_proj = self._original_mod.o_proj
 
  def projection(
  self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
@@ -745,8 +790,8 @@ class DecoderOnlyAttention(nn.Module):
  def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
  return apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
- def get_attn_scale(self):
- return 1 / math.sqrt(self.head_dim)
+ def get_attn_scale(self, self_attn):
+ return 1 / math.sqrt(self_attn.head_dim)
 
  def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
  if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
@@ -803,6 +848,7 @@ class DecoderOnlyAttention(nn.Module):
  block_size=self.kvcache_block_size,
  k_scale=k_scale,
  v_scale=v_scale,
+ s_aux=getattr(self, "sinks", None),
  )
 
  # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
@@ -830,23 +876,27 @@ class AttentionOp(nn.Module):
  num_heads: int,
  head_dim: int,
  num_key_value_heads: int,
- use_attention_mask: bool,
- use_position_ids: bool,
- quantization: Optional[RBLNQuantizationConfig] = None,
+ rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+ is_sliding: bool = False,
  ):
  super().__init__()
  self.num_heads = num_heads
  self.head_dim = head_dim
  self.num_key_value_heads = num_key_value_heads
  self.phase = "prefill"
- self.use_attention_mask = use_attention_mask
- self.use_position_ids = use_position_ids
- self.quantization = quantization
+ self.rbln_config = rbln_config
+ self.use_attention_mask = True if is_sliding else rbln_config.use_attention_mask
+ self.use_position_ids = rbln_config.use_position_ids
+ self.quantization = rbln_config.quantization
 
  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
- if self.use_attention_mask and not self.use_position_ids:
- attn_op_name = "paged_attn_"
+
+ if self.use_attention_mask:
+ if self.rbln_config.use_position_ids:
+ attn_op_name = "paged_causal_attn_"
+ else:
+ attn_op_name = "paged_attn_"
  else:
  attn_op_name = "paged_causal_attn_"
 
@@ -871,6 +921,7 @@ class AttentionOp(nn.Module):
  block_size: int,
  k_scale: Optional[torch.Tensor] = None,
  v_scale: Optional[torch.Tensor] = None,
+ s_aux: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  """Compute attention with static shapes and explicit cache management.
 
@@ -887,6 +938,7 @@ class AttentionOp(nn.Module):
  block_size: Block size for paged attention
  k_scale: Scale applied to key
  v_scale: Scale applied to value
+ s_aux: Auxiliary states for attention sinks
 
  Returns:
  Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -895,7 +947,7 @@ class AttentionOp(nn.Module):
  key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
  value_state = value_state.unsqueeze(2)
 
- if self.use_attention_mask and not self.use_position_ids:
+ if self.use_attention_mask and not self.rbln_config.use_position_ids:
  attn_mask = attn_mask.unsqueeze(2)
 
  if self.phase == "decode":
@@ -927,8 +979,14 @@ class AttentionOp(nn.Module):
  op_args["mask"] = attn_mask
 
  if self.phase == "prefill" or self.phase == "image_prefill":
- if not self.use_attention_mask or self.use_position_ids:
- op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+ use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+ if use_image_prefill:
+ op_args["is_bidirectional"] = self.phase == "image_prefill"
+ else:
+ if not self.use_attention_mask:
+ op_args["is_bidirectional"] = False
+ elif self.use_attention_mask and self.rbln_config.use_position_ids:
+ op_args["is_bidirectional"] = True
 
  if self.quantization and self.quantization.kv_caches == "fp8":
  if past_key_state.dtype != torch.float8_e4m3fn:
@@ -936,6 +994,9 @@ class AttentionOp(nn.Module):
  op_args["k_scale"] = k_scale
  op_args["v_scale"] = v_scale
 
+ if s_aux is not None:
+ op_args["s_aux"] = s_aux
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -956,24 +1017,26 @@ class FlashAttentionOp(AttentionOp):
  head_dim: int,
  num_key_value_heads: int,
  kvcache_partition_len: int,
- use_attention_mask: bool,
- use_position_ids: bool,
- quantization: Optional[RBLNQuantizationConfig] = None,
+ rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+ is_sliding: bool = False,
  ):
  super().__init__(
  num_heads=num_heads,
  head_dim=head_dim,
  num_key_value_heads=num_key_value_heads,
- use_attention_mask=use_attention_mask,
- use_position_ids=use_position_ids,
- quantization=quantization,
+ rbln_config=rbln_config,
+ is_sliding=is_sliding,
  )
  self.kvcache_partition_size = kvcache_partition_len
 
  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
- if self.use_attention_mask and not self.use_position_ids:
- attn_op_name = "paged_flash_attn_"
+
+ if self.use_attention_mask:
+ if self.rbln_config.use_position_ids:
+ attn_op_name = "paged_flash_causal_attn_"
+ else:
+ attn_op_name = "paged_flash_attn_"
  else:
  attn_op_name = "paged_flash_causal_attn_"
 
@@ -998,11 +1061,13 @@ class FlashAttentionOp(AttentionOp):
  block_size,
  k_scale=None,
  v_scale=None,
+ s_aux=None,
  ):
  # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
  key_state = key_state.unsqueeze(2)
  value_state = value_state.unsqueeze(2)
- if self.use_attention_mask and not self.use_position_ids:
+
+ if self.use_attention_mask and not self.rbln_config.use_position_ids:
  attn_mask = attn_mask.unsqueeze(2)
 
  if self.phase == "decode":
@@ -1035,8 +1100,14 @@ class FlashAttentionOp(AttentionOp):
  op_args["mask"] = attn_mask
 
  if self.phase == "prefill" or self.phase == "image_prefill":
- if not self.use_attention_mask or self.use_position_ids:
- op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+ use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+ if use_image_prefill:
+ op_args["is_bidirectional"] = self.phase == "image_prefill"
+ else:
+ if not self.use_attention_mask:
+ op_args["is_bidirectional"] = False
+ elif self.use_attention_mask and self.rbln_config.use_position_ids:
+ op_args["is_bidirectional"] = True
 
  if self.quantization and self.quantization.kv_caches == "fp8":
  if past_key_state.dtype != torch.float8_e4m3fn:
@@ -1044,6 +1115,9 @@ class FlashAttentionOp(AttentionOp):
  op_args["k_scale"] = k_scale
  op_args["v_scale"] = v_scale
 
+ if s_aux is not None:
+ op_args["s_aux"] = s_aux
+
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
  if attn_op is None:
@@ -1058,6 +1132,22 @@ class FlashAttentionOp(AttentionOp):
 
 
  class SlidingWindowAttentionOp(AttentionOp):
+ def __init__(
+ self,
+ num_heads: int,
+ head_dim: int,
+ num_key_value_heads: int,
+ rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+ ):
+ super().__init__(
+ num_heads=num_heads,
+ head_dim=head_dim,
+ num_key_value_heads=num_key_value_heads,
+ rbln_config=rbln_config,
+ is_sliding=True,
+ )
+ self.quantization = None # Sliding window attention does not support quantization
+
  def get_attn_op_name(self):
  phase = "decode" if self.phase == "decode" else "prefill"
  if not self.use_attention_mask:
@@ -1080,6 +1170,7 @@ class SlidingWindowAttentionOp(AttentionOp):
  block_size: int,
  k_scale: Optional[torch.Tensor] = None,
  v_scale: Optional[torch.Tensor] = None,
+ s_aux: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  assert self.quantization is None, "Sliding window attention does not support quantization"
  assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
@@ -1115,7 +1206,19 @@ class SlidingWindowAttentionOp(AttentionOp):
  }
 
  if self.phase == "prefill" or self.phase == "image_prefill":
- op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
+ use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+ if use_image_prefill:
+ op_args["is_bidirectional"] = self.phase == "image_prefill"
+ else:
+ if self.use_attention_mask and self.rbln_config.use_position_ids:
+ op_args["is_bidirectional"] = True
+ else:
+ op_args["is_bidirectional"] = False
+ elif self.phase == "decode":
+ op_args["attn_mask"] = attn_mask
+
+ if s_aux is not None:
+ op_args["s_aux"] = s_aux
 
  attn_op_name = self.get_attn_op_name()
  attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1145,7 +1248,7 @@ class RotaryEmbedding(nn.Module):
  else:
  rope_type = "default"
 
- inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, "cpu", max_seq_len_cached)
  cache_position = torch.arange(0, max_seq_len_cached)
  cache_position_expanded = cache_position[:, None]
 
@@ -1222,3 +1325,22 @@ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tu
  query_states = torch.cat((query_rot, query_pass), dim=-1)
  key_states = torch.cat((key_rot, key_pass), dim=-1)
  return query_states, key_states
+
+
+ def _get_attr_from_candidates(
+ src: object,
+ candidates: Optional[List[str]] = None,
+ ):
+ """
+ Get an attribute from a list of candidate names.
+
+ - If `candidates` is None, this attribute is treated as optional and returns None.
+ - Otherwise, returns `getattr(src, name)` for the first `name` in `candidates` that exists on `src`.
+ - Raises AttributeError if `candidates` is provided but none of the names exist on `src`.
+ """
+ if candidates is None:
+ return None
+ for name in candidates:
+ if hasattr(src, name):
+ return getattr(src, name)
+ raise AttributeError(f"None of the attributes {candidates} exist in {src}")
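For context on the `_original_mod` removal above: the rewritten classes now resolve family-specific submodule names (embeddings, layernorms, o_proj) through the new `_get_attr_from_candidates` helper and the `*_ATTRS` candidate lists instead of keeping a reference to the original Hugging Face module. A minimal standalone sketch of that lookup pattern follows; the toy layer classes and the standalone copy of the helper are illustrative only, not part of the package.

from typing import List, Optional

import torch.nn as nn


def _get_attr_from_candidates(src: object, candidates: Optional[List[str]] = None):
    # Same behavior as the helper added in this diff: None marks the submodule
    # as optional; otherwise return the first matching attribute or raise.
    if candidates is None:
        return None
    for name in candidates:
        if hasattr(src, name):
            return getattr(src, name)
    raise AttributeError(f"None of the attributes {candidates} exist in {src}")


# Toy stand-ins for two HF-style decoder layers (illustrative only).
class LlamaStyleLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(8)


class Gpt2StyleLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(8)


_PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]

for layer in (LlamaStyleLayer(), Gpt2StyleLayer()):
    # Both families resolve to their own pre-attention layernorm attribute.
    norm = _get_attr_from_candidates(layer, _PRE_ATTN_LAYERNORM)
    print(type(layer).__name__, "->", type(norm).__name__)
    # A None candidate list marks the submodule as optional
    # (e.g. the pre/post feed-forward layernorms in the base classes).
    assert _get_attr_from_candidates(layer, None) is None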