optimum-rbln 0.2.1a0__py3-none-any.whl → 0.2.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (115)
  1. optimum/rbln/__init__.py +3 -10
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +1 -10
  4. optimum/rbln/diffusers/modeling_diffusers.py +1 -10
  5. optimum/rbln/diffusers/models/__init__.py +1 -10
  6. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -10
  7. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -10
  8. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -10
  9. optimum/rbln/diffusers/models/controlnet.py +1 -10
  10. optimum/rbln/diffusers/models/transformers/__init__.py +1 -10
  11. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -10
  12. optimum/rbln/diffusers/models/unets/__init__.py +1 -10
  13. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -10
  14. optimum/rbln/diffusers/pipelines/__init__.py +1 -10
  15. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +1 -10
  16. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -10
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +1 -10
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -10
  19. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -10
  20. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -10
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -10
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -10
  23. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -10
  24. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -10
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +1 -10
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -10
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -10
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -10
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -10
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -10
  31. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -10
  32. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -10
  33. optimum/rbln/modeling.py +1 -10
  34. optimum/rbln/modeling_base.py +1 -10
  35. optimum/rbln/modeling_config.py +1 -10
  36. optimum/rbln/ops/__init__.py +1 -10
  37. optimum/rbln/ops/attn.py +9 -18
  38. optimum/rbln/ops/flash_attn.py +5 -14
  39. optimum/rbln/ops/kv_cache_update.py +1 -10
  40. optimum/rbln/transformers/__init__.py +5 -12
  41. optimum/rbln/transformers/modeling_alias.py +1 -14
  42. optimum/rbln/transformers/modeling_generic.py +40 -21
  43. optimum/rbln/transformers/modeling_rope_utils.py +28 -0
  44. optimum/rbln/transformers/models/__init__.py +3 -12
  45. optimum/rbln/transformers/models/auto/__init__.py +1 -10
  46. optimum/rbln/transformers/models/auto/auto_factory.py +1 -10
  47. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -10
  48. optimum/rbln/transformers/models/bart/__init__.py +1 -10
  49. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -10
  50. optimum/rbln/transformers/models/bart/modeling_bart.py +14 -13
  51. optimum/rbln/transformers/models/bert/__init__.py +2 -11
  52. optimum/rbln/transformers/models/bert/modeling_bert.py +23 -13
  53. optimum/rbln/transformers/models/clip/__init__.py +1 -10
  54. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -10
  55. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -10
  56. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +54 -69
  57. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +14 -14
  58. optimum/rbln/transformers/models/dpt/__init__.py +1 -10
  59. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -10
  60. optimum/rbln/transformers/models/exaone/__init__.py +1 -10
  61. optimum/rbln/transformers/models/exaone/exaone_architecture.py +1 -10
  62. optimum/rbln/transformers/models/exaone/modeling_exaone.py +1 -10
  63. optimum/rbln/transformers/models/gemma/__init__.py +1 -10
  64. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -10
  65. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -10
  66. optimum/rbln/transformers/models/gpt2/__init__.py +1 -10
  67. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +1 -10
  68. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -10
  69. optimum/rbln/transformers/models/llama/__init__.py +1 -10
  70. optimum/rbln/transformers/models/llama/llama_architecture.py +1 -10
  71. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -10
  72. optimum/rbln/transformers/models/llava_next/__init__.py +1 -10
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +95 -89
  74. optimum/rbln/transformers/models/midm/__init__.py +1 -10
  75. optimum/rbln/transformers/models/midm/midm_architecture.py +1 -10
  76. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -10
  77. optimum/rbln/transformers/models/mistral/__init__.py +1 -10
  78. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -10
  79. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -10
  80. optimum/rbln/transformers/models/phi/__init__.py +1 -10
  81. optimum/rbln/transformers/models/phi/modeling_phi.py +1 -10
  82. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -10
  83. optimum/rbln/transformers/models/qwen2/__init__.py +1 -10
  84. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +1 -10
  85. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +1 -10
  86. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -10
  87. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -10
  88. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +16 -42
  89. optimum/rbln/transformers/models/t5/__init__.py +1 -10
  90. optimum/rbln/transformers/models/t5/modeling_t5.py +14 -15
  91. optimum/rbln/transformers/models/t5/t5_architecture.py +30 -16
  92. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -10
  93. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -10
  94. optimum/rbln/transformers/models/whisper/__init__.py +1 -10
  95. optimum/rbln/transformers/models/whisper/generation_whisper.py +2 -11
  96. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -10
  97. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -10
  98. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -10
  99. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +24 -12
  100. optimum/rbln/transformers/utils/rbln_quantization.py +6 -10
  101. optimum/rbln/utils/__init__.py +1 -10
  102. optimum/rbln/utils/decorator_utils.py +1 -10
  103. optimum/rbln/utils/hub.py +1 -10
  104. optimum/rbln/utils/import_utils.py +1 -10
  105. optimum/rbln/utils/logging.py +1 -10
  106. optimum/rbln/utils/model_utils.py +1 -10
  107. optimum/rbln/utils/runtime_utils.py +1 -10
  108. optimum/rbln/utils/save_utils.py +2 -10
  109. optimum/rbln/utils/submodule.py +1 -10
  110. {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/METADATA +6 -4
  111. optimum_rbln-0.2.1a2.dist-info/RECORD +114 -0
  112. optimum_rbln-0.2.1a2.dist-info/licenses/LICENSE +201 -0
  113. optimum_rbln-0.2.1a0.dist-info/RECORD +0 -114
  114. optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +0 -288
  115. {optimum_rbln-0.2.1a0.dist-info → optimum_rbln-0.2.1a2.dist-info}/WHEEL +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .phi_architecture import PhiWrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import TYPE_CHECKING, Optional, Tuple
25
16
 
26
17
  import torch
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_qwen2 import RBLNQwen2ForCausalLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ....utils import logging
25
16
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
26
17
  from .qwen2_architecture import QWEN2Wrapper
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
25
16
 
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_seq2seq import RBLNModelForSeq2SeqLM
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  import logging
26
17
  from abc import ABC
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import Tuple
25
16
 
26
17
  import torch
@@ -429,7 +420,7 @@ class Seq2SeqSelfAttention(nn.Module):
429
420
  pass
430
421
 
431
422
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
432
- return tensor.view(bsz, 1, seq_len, 1, self.num_heads, self.head_dim).transpose(2, 4)
423
+ return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)
433
424
 
434
425
  def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
435
426
  """Projects input hidden states into query, key, and value representations.
@@ -459,38 +450,21 @@ class Seq2SeqSelfAttention(nn.Module):
459
450
  key_states = self._shape(key_states, -1, bsz)
460
451
  value_states = self._shape(value_states, -1, bsz)
461
452
 
462
- all_key_states = []
463
- all_value_states = []
464
- all_attn_output = []
465
- for b_idx in range(bsz):
466
- query_state = query_states[b_idx]
467
- key_state = key_states[b_idx]
468
- value_state = value_states[b_idx]
469
- attn_mask = attention_mask[b_idx].unsqueeze(0).unsqueeze(2)
470
- past_key_state = past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim)
471
- past_value_state = past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim)
472
-
473
- attn_output, key_state, value_state = self.attn_decode(
474
- query_state,
475
- key_state,
476
- value_state,
477
- attn_mask,
478
- past_key_state,
479
- past_value_state,
480
- cache_position[b_idx][0],
481
- torch.tensor(1.0, dtype=torch.float32), # scale
482
- )
483
-
484
- attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim).transpose(1, 2)
485
- attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
486
-
487
- all_key_states.append(key_state.squeeze(2))
488
- all_value_states.append(value_state.squeeze(2))
489
- all_attn_output.append(attn_output)
453
+ attn_output, key_states, value_states = self.attn_decode(
454
+ query_states,
455
+ key_states,
456
+ value_states,
457
+ attention_mask.unsqueeze(
458
+ 2
459
+ ), # Unsqueeze group axis since CustomKernel expects it for group query attention
460
+ past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
461
+ past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
462
+ cache_position.squeeze(1),
463
+ torch.tensor(1.0, dtype=torch.float32), # scale
464
+ )
490
465
 
491
- key_states = torch.cat(all_key_states, dim=0)
492
- value_states = torch.cat(all_value_states, dim=0)
493
- attn_output = torch.cat(all_attn_output, dim=0)
466
+ attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
467
+ attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)
494
468
 
495
469
  attn_output = self.out_proj(attn_output)
496
470
  present_key_value = (key_states, value_states)
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,20 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
26
17
 
27
18
  import torch
28
- import transformers
29
19
  from transformers import (
30
20
  AutoModelForTextEncoding,
31
21
  PretrainedConfig,
@@ -130,20 +120,29 @@ class RBLNT5EncoderModel(RBLNModel):
130
120
  if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
131
121
  raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
132
122
 
123
+ signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
124
+
133
125
  if rbln_model_input_names is None:
134
126
  for tokenizer in preprocessors:
135
127
  if hasattr(tokenizer, "model_input_names"):
136
- rbln_model_input_names = tokenizer.model_input_names
128
+ rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
129
+
130
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
131
+ if invalid_params:
132
+ raise ValueError(f"Invalid model input names: {invalid_params}")
137
133
  break
138
134
  if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
139
135
  rbln_model_input_names = cls.rbln_model_input_names
140
136
  elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
141
- original_model_class = getattr(transformers, model_config.architectures[0])
142
- input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
143
137
  raise ValueError(
144
138
  "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
145
- f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(input_names_order)})"
139
+ f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(signature_params)})"
146
140
  )
141
+ else:
142
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
143
+ if invalid_params:
144
+ raise ValueError(f"Invalid model input names: {invalid_params}")
145
+ rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
147
146
 
148
147
  if rbln_batch_size is None:
149
148
  rbln_batch_size = 1
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import Tuple
25
16
 
26
17
  import torch
@@ -156,6 +147,11 @@ class T5CrossAttention(nn.Module):
156
147
  def __init__(self, attn):
157
148
  super().__init__()
158
149
  self.attn = attn
150
+ self.q = attn.q
151
+ self.o = attn.o
152
+ self.n_heads = attn.n_heads
153
+ self.key_value_proj_dim = attn.key_value_proj_dim
154
+ self.inner_dim = attn.inner_dim
159
155
 
160
156
  def forward(
161
157
  self,
@@ -164,9 +160,27 @@ class T5CrossAttention(nn.Module):
164
160
  attention_mask: torch.Tensor = None,
165
161
  key_value_states: torch.Tensor = None,
166
162
  ):
167
- return self.attn(
168
- hidden_states=hidden_states,
169
- past_key_value=past_key_value,
170
- position_bias=attention_mask,
171
- key_value_states=key_value_states,
172
- )
163
+ batch_size = hidden_states.shape[0]
164
+
165
+ query_states = self.q(hidden_states)
166
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
167
+
168
+ # reuse k,v, cross_attentions
169
+ key_states = past_key_value[0]
170
+ value_states = past_key_value[1]
171
+
172
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
173
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
174
+ scores += attention_mask
175
+
176
+ # (batch_size, n_heads, seq_length, key_length)
177
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
178
+ attn_output = torch.matmul(attn_weights, value_states)
179
+
180
+ attn_output = attn_output.transpose(1, 2).contiguous()
181
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
182
+ attn_output = self.o(attn_output)
183
+
184
+ outputs = (attn_output, past_key_value)
185
+
186
+ return outputs
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import logging
25
16
  from typing import TYPE_CHECKING, Any, Dict, Union
26
17
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_whisper import RBLNWhisperForConditionalGeneration
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
1
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Copyright 2024 Rebellions Inc.
15
+ # Copyright 2025 Rebellions Inc. All rights reserved.
16
16
 
17
17
  # Licensed under the Apache License, Version 2.0 (the "License");
18
18
  # you may not use this file except in compliance with the License.
@@ -26,15 +26,6 @@
26
26
  # See the License for the specific language governing permissions and
27
27
  # limitations under the License.
28
28
 
29
- # Portions of this software are licensed under the Apache License,
30
- # Version 2.0. See the NOTICE file distributed with this work for
31
- # additional information regarding copyright ownership.
32
-
33
- # All other portions of this software, including proprietary code,
34
- # are the intellectual property of Rebellions Inc. and may not be
35
- # copied, modified, or distributed without prior written permission
36
- # from Rebellions Inc.
37
-
38
29
  """
39
30
  Generation utilities for Whisper.
40
31
  Modified from `transformers.models.whisper.generation_whisper.py`
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import inspect
25
16
  import logging
26
17
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from typing import Optional, Tuple, Union
25
16
 
26
17
  import torch
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,13 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  from .modeling_xlm_roberta import RBLNXLMRobertaModel
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
15
+ import inspect
24
16
  import logging
25
17
  from typing import TYPE_CHECKING, Optional, Union
26
18
 
@@ -66,9 +58,29 @@ class RBLNXLMRobertaModel(RBLNModel):
66
58
  if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
67
59
  raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
68
60
 
61
+ signature_params = inspect.signature(cls.hf_class.forward).parameters.keys()
62
+
69
63
  if rbln_model_input_names is None:
70
- # These are BERT's inputs
71
- rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
64
+ for tokenizer in preprocessors:
65
+ if hasattr(tokenizer, "model_input_names"):
66
+ rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
67
+
68
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
69
+ if invalid_params:
70
+ raise ValueError(f"Invalid model input names: {invalid_params}")
71
+ break
72
+ if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
73
+ rbln_model_input_names = cls.rbln_model_input_names
74
+ elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
75
+ raise ValueError(
76
+ "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
77
+ f"and be sure to make the order of the inputs same as XLMRobertaModel forward() arguments like ({list(signature_params)})"
78
+ )
79
+ else:
80
+ invalid_params = set(rbln_model_input_names) - set(signature_params)
81
+ if invalid_params:
82
+ raise ValueError(f"Invalid model input names: {invalid_params}")
83
+ rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
72
84
 
73
85
  if rbln_batch_size is None:
74
86
  rbln_batch_size = 1
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Rebellions Inc.
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
2
 
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,15 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- # Portions of this software are licensed under the Apache License,
16
- # Version 2.0. See the NOTICE file distributed with this work for
17
- # additional information regarding copyright ownership.
18
-
19
- # All other portions of this software, including proprietary code,
20
- # are the intellectual property of Rebellions Inc. and may not be
21
- # copied, modified, or distributed without prior written permission
22
- # from Rebellions Inc.
23
-
24
15
  import functools
25
16
  import glob
26
17
  import os
@@ -135,6 +126,8 @@ def update_layers_to_quantize(module: torch.nn.Module) -> None:
135
126
  """
136
127
  Updates specified linear layers to quantized (qlinear) layers in the given module.
137
128
  """
129
+
130
+ logger.debug("Updating layers to be quantized") # TODO(jongho): remove.
138
131
  processed_layers = []
139
132
 
140
133
  for name, layer in module.named_modules():
@@ -151,6 +144,7 @@ def load_weights(model, model_id, n_layer=None):
151
144
  """
152
145
  Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
153
146
  """
147
+ logger.debug("Loading the quantized weights into the CPU.") # TODO(jongho): remove.
154
148
 
155
149
  model_params = dict(model.named_parameters(recurse=True))
156
150
  model_buffers = dict(model.named_buffers(recurse=True))
@@ -172,6 +166,8 @@ def load_weights(model, model_id, n_layer=None):
172
166
  elif key in model_buffers:
173
167
  model_buffers[key].data.copy_(value)
174
168
 
169
+ logger.debug("Loaded the quantized weights into the CPU.")
170
+
175
171
 
176
172
  def is_target_for_qlinear_replacement(layer_name: str, layer: torch.nn.Module) -> bool:
177
173
  """