optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. optimum/rbln/__init__.py +173 -35
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +816 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +111 -137
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +56 -71
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
  31. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
  33. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
  34. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
  36. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
  38. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
  42. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
  43. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
  44. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
  45. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
  46. optimum/rbln/modeling.py +66 -40
  47. optimum/rbln/modeling_base.py +111 -86
  48. optimum/rbln/ops/__init__.py +4 -7
  49. optimum/rbln/ops/attn.py +271 -205
  50. optimum/rbln/ops/flash_attn.py +161 -67
  51. optimum/rbln/ops/kv_cache_update.py +4 -40
  52. optimum/rbln/ops/linear.py +25 -0
  53. optimum/rbln/transformers/__init__.py +97 -8
  54. optimum/rbln/transformers/configuration_alias.py +49 -0
  55. optimum/rbln/transformers/configuration_generic.py +142 -0
  56. optimum/rbln/transformers/modeling_generic.py +193 -280
  57. optimum/rbln/transformers/models/__init__.py +120 -32
  58. optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
  59. optimum/rbln/transformers/models/bart/__init__.py +2 -0
  60. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  61. optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
  62. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  63. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  64. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  65. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  66. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  67. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  68. optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
  69. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
  71. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
  72. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  73. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  74. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  75. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  76. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  77. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  78. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  79. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  80. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  81. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  82. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
  83. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
  84. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  85. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  86. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  87. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  88. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
  89. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  90. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  91. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  92. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  93. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  94. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  97. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  102. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
  104. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
  105. optimum/rbln/transformers/models/t5/__init__.py +2 -0
  106. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  107. optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
  108. optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
  109. optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
  110. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  111. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
  112. optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  114. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  115. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  116. optimum/rbln/transformers/models/whisper/__init__.py +2 -0
  117. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  118. optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
  119. optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  123. optimum/rbln/utils/hub.py +2 -2
  124. optimum/rbln/utils/import_utils.py +23 -6
  125. optimum/rbln/utils/model_utils.py +4 -4
  126. optimum/rbln/utils/runtime_utils.py +33 -2
  127. optimum/rbln/utils/submodule.py +36 -44
  128. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
  129. optimum_rbln-0.7.4.dist-info/RECORD +169 -0
  130. optimum/rbln/modeling_config.py +0 -310
  131. optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
  132. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
  133. {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/__init__.py
@@ -32,28 +32,72 @@ _import_structure = {
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
     ],
-    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
-    "bert": ["RBLNBertModel", "RBLNBertForQuestionAnswering", "RBLNBertForMaskedLM"],
+    "bart": [
+        "RBLNBartForConditionalGeneration",
+        "RBLNBartModel",
+        "RBLNBartForConditionalGenerationConfig",
+        "RBLNBartModelConfig",
+    ],
+    "bert": [
+        "RBLNBertModel",
+        "RBLNBertModelConfig",
+        "RBLNBertForQuestionAnswering",
+        "RBLNBertForQuestionAnsweringConfig",
+        "RBLNBertForMaskedLM",
+        "RBLNBertForMaskedLMConfig",
+    ],
     "clip": [
         "RBLNCLIPTextModel",
+        "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPTextModelWithProjectionConfig",
         "RBLNCLIPVisionModel",
+        "RBLNCLIPVisionModelConfig",
         "RBLNCLIPVisionModelWithProjection",
+        "RBLNCLIPVisionModelWithProjectionConfig",
+    ],
+    "qwen2_5_vl": [
+        "RBLNQwen2_5_VisionTransformerPretrainedModel",
+        "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
+        "RBLNQwen2_5_VLForConditionalGeneration",
+        "RBLNQwen2_5_VLForConditionalGenerationConfig",
+    ],
+    "decoderonly": [
+        "RBLNDecoderOnlyModelForCausalLM",
+        "RBLNDecoderOnlyModelForCausalLMConfig",
+    ],
+    "dpt": [
+        "RBLNDPTForDepthEstimation",
+        "RBLNDPTForDepthEstimationConfig",
+    ],
+    "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
+    "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
+    "idefics3": [
+        "RBLNIdefics3VisionTransformer",
+        "RBLNIdefics3ForConditionalGeneration",
+        "RBLNIdefics3ForConditionalGenerationConfig",
+        "RBLNIdefics3VisionTransformerConfig",
+    ],
+    "llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
+    "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
+    "mistral": ["RBLNMistralForCausalLM", "RBLNMistralForCausalLMConfig"],
+    "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig"],
+    "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig"],
+    "time_series_transformers": [
+        "RBLNTimeSeriesTransformerForPrediction",
+        "RBLNTimeSeriesTransformerForPredictionConfig",
+    ],
+    "t5": [
+        "RBLNT5EncoderModel",
+        "RBLNT5ForConditionalGeneration",
+        "RBLNT5EncoderModelConfig",
+        "RBLNT5ForConditionalGenerationConfig",
     ],
-    "dpt": ["RBLNDPTForDepthEstimation"],
-    "exaone": ["RBLNExaoneForCausalLM"],
-    "gemma": ["RBLNGemmaForCausalLM"],
-    "gpt2": ["RBLNGPT2LMHeadModel"],
-    "llama": ["RBLNLlamaForCausalLM"],
-    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
-    "midm": ["RBLNMidmLMHeadModel"],
-    "mistral": ["RBLNMistralForCausalLM"],
-    "phi": ["RBLNPhiForCausalLM"],
-    "qwen2": ["RBLNQwen2ForCausalLM"],
-    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
-    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
-    "whisper": ["RBLNWhisperForConditionalGeneration"],
-    "xlm_roberta": ["RBLNXLMRobertaModel"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC", "RBLNWav2Vec2ForCTCConfig"],
+    "whisper": ["RBLNWhisperForConditionalGeneration", "RBLNWhisperForConditionalGenerationConfig"],
+    "xlm_roberta": ["RBLNXLMRobertaModel", "RBLNXLMRobertaModelConfig"],
 }
 
 if TYPE_CHECKING:
@@ -71,28 +115,72 @@ if TYPE_CHECKING:
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForVision2Seq,
     )
-    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
-    from .bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
+    from .bart import (
+        RBLNBartForConditionalGeneration,
+        RBLNBartForConditionalGenerationConfig,
+        RBLNBartModel,
+        RBLNBartModelConfig,
+    )
+    from .bert import (
+        RBLNBertForMaskedLM,
+        RBLNBertForMaskedLMConfig,
+        RBLNBertForQuestionAnswering,
+        RBLNBertForQuestionAnsweringConfig,
+        RBLNBertModel,
+        RBLNBertModelConfig,
+    )
     from .clip import (
         RBLNCLIPTextModel,
+        RBLNCLIPTextModelConfig,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPTextModelWithProjectionConfig,
         RBLNCLIPVisionModel,
+        RBLNCLIPVisionModelConfig,
         RBLNCLIPVisionModelWithProjection,
+        RBLNCLIPVisionModelWithProjectionConfig,
+    )
+    from .decoderonly import (
+        RBLNDecoderOnlyModelForCausalLM,
+        RBLNDecoderOnlyModelForCausalLMConfig,
+    )
+    from .dpt import (
+        RBLNDPTForDepthEstimation,
+        RBLNDPTForDepthEstimationConfig,
+    )
+    from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
+    from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
+    from .idefics3 import (
+        RBLNIdefics3ForConditionalGeneration,
+        RBLNIdefics3ForConditionalGenerationConfig,
+        RBLNIdefics3VisionTransformer,
+        RBLNIdefics3VisionTransformerConfig,
+    )
+    from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+    from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
+    from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
+    from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
+    from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig
+    from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig
+    from .qwen2_5_vl import (
+        RBLNQwen2_5_VisionTransformerPretrainedModel,
+        RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+        RBLNQwen2_5_VLForConditionalGeneration,
+        RBLNQwen2_5_VLForConditionalGenerationConfig,
+    )
+    from .t5 import (
+        RBLNT5EncoderModel,
+        RBLNT5EncoderModelConfig,
+        RBLNT5ForConditionalGeneration,
+        RBLNT5ForConditionalGenerationConfig,
+    )
+    from .time_series_transformers import (
+        RBLNTimeSeriesTransformerForPrediction,
+        RBLNTimeSeriesTransformerForPredictionConfig,
     )
-    from .dpt import RBLNDPTForDepthEstimation
-    from .exaone import RBLNExaoneForCausalLM
-    from .gemma import RBLNGemmaForCausalLM
-    from .gpt2 import RBLNGPT2LMHeadModel
-    from .llama import RBLNLlamaForCausalLM
-    from .llava_next import RBLNLlavaNextForConditionalGeneration
-    from .midm import RBLNMidmLMHeadModel
-    from .mistral import RBLNMistralForCausalLM
-    from .phi import RBLNPhiForCausalLM
-    from .qwen2 import RBLNQwen2ForCausalLM
-    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
-    from .wav2vec2 import RBLNWav2Vec2ForCTC
-    from .whisper import RBLNWhisperForConditionalGeneration
-    from .xlm_roberta import RBLNXLMRobertaModel
+    from .wav2vec2 import RBLNWav2Vec2ForCTC, RBLNWav2Vec2ForCTCConfig
+    from .whisper import RBLNWhisperForConditionalGeneration, RBLNWhisperForConditionalGenerationConfig
+    from .xlm_roberta import RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig
 
 else:
     import sys
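
Taken together, the two hunks above show the central change in 0.7.4: every exported RBLN model class now has a companion `*Config` class in the same namespace. A minimal usage sketch follows; the checkpoint id and option values are illustrative rather than taken from this diff, and `batch_size`/`max_seq_len` are assumed fields of the new decoder-only config added in `configuration_decoderonly.py`:

    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

    # Compile-time options travel in a typed config object instead of
    # loose rbln_* keyword arguments.
    config = RBLNLlamaForCausalLMConfig(batch_size=1, max_seq_len=8192)
    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
        export=True,                 # compile from the HuggingFace weights
        rbln_config=config,
    )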

optimum/rbln/transformers/models/auto/auto_factory.py
@@ -20,8 +20,8 @@ from transformers import AutoConfig, PretrainedConfig
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.auto_factory import _get_model_class
 
+from optimum.rbln.configuration_utils import RBLNAutoConfig
 from optimum.rbln.modeling_base import RBLNBaseModel
-from optimum.rbln.modeling_config import RBLNConfig
 from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
 
 
@@ -48,7 +48,7 @@ class _BaseAutoModelClass:
 
         Args:
             pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
-            export (bool): Whether to infer the class based on Hugging Face (HF) architecture.
+            export (bool): Whether to infer the class based on HuggingFace (HF) architecture.
             kwargs: Additional arguments for configuration and loading.
 
         Returns:
@@ -86,14 +86,14 @@ class _BaseAutoModelClass:
         **kwargs,
     ):
         """
-        Infer the Hugging Face model class based on the configuration or model name.
+        Infer the HuggingFace model class based on the configuration or model name.
 
         Args:
             pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
             kwargs: Additional arguments for configuration and loading.
 
         Returns:
-            PretrainedModel: The inferred Hugging Face model class.
+            PretrainedModel: The inferred HuggingFace model class.
         """
 
         # Try to load configuration if provided or retrieve it from the model ID
@@ -154,9 +154,9 @@
         model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
             model_id=pretrained_model_name_or_path, **filtered_kwargs
         )
-        rbln_config = RBLNConfig.load(model_path_subfolder)
+        rbln_config = RBLNAutoConfig.load(model_path_subfolder)
 
-        return rbln_config.meta["cls"]
+        return rbln_config.rbln_model_cls_name
 
     @classmethod
     def from_pretrained(
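
The auto classes now resolve the concrete model class from the saved config itself rather than from the old `meta["cls"]` field. A minimal sketch of that lookup, assuming a local directory previously produced by `save_pretrained()` (the path is hypothetical):

    from optimum.rbln.configuration_utils import RBLNAutoConfig

    # rbln_config.json in a compiled-model directory records which RBLN
    # class produced it; rbln_model_cls_name exposes that name.
    rbln_config = RBLNAutoConfig.load("./compiled-bart")
    print(rbln_config.rbln_model_cls_name)  # e.g. "RBLNBartForConditionalGeneration"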

optimum/rbln/transformers/models/bart/__init__.py
@@ -12,4 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ....ops import paged_attn_decode, paged_causal_attn_decode
+from .configuration_bart import RBLNBartForConditionalGenerationConfig, RBLNBartModelConfig
 from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel

optimum/rbln/transformers/models/bart/configuration_bart.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+    pass
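
The new per-model config modules are mostly thin aliases like these: the subclasses add no fields of their own and exist so every model has a distinctly named config type that can be serialized and resolved by name. Option names such as `enc_max_seq_len` come from the `RBLNModelForSeq2SeqLMConfig` base; its exact constructor signature is not shown in this diff, so treat the sketch as an assumption:

    from optimum.rbln import RBLNBartForConditionalGenerationConfig

    # The subclass contributes only its class name, which identifies the
    # model family when the config is saved and reloaded.
    cfg = RBLNBartForConditionalGenerationConfig(enc_max_seq_len=1024)
    assert type(cfg).__name__ == "RBLNBartForConditionalGenerationConfig"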

optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -13,109 +13,36 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable
 
-from transformers import BartForConditionalGeneration, PretrainedConfig, PreTrainedModel
+from transformers import BartForConditionalGeneration, PreTrainedModel
 
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 from .bart_architecture import BartWrapper
+from .configuration_bart import RBLNBartForConditionalGenerationConfig
 
 
 logger = get_logger()
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+    from transformers import PreTrainedModel
 
 
-class RBLNBartModel(RBLNModel):
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-
-        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-                    # BartModel's forward() does not take token_type_ids as input.
-                    # (Added because some of the tokenizers includes 'token_type_ids')
-                    if "token_type_ids" in rbln_model_input_names:
-                        rbln_model_input_names.remove("token_type_ids")
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        enc_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="encoder")
-        dec_compile_config = RBLNCompileConfig(input_info=input_info, compiled_model_name="decoder")
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[enc_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
+class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
+    pass
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    support_causal_attn = True
+
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        enc_max_seq_len = (
-            rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNBartForConditionalGenerationConfig):
+        return BartWrapper(
+            model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
-        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
-
-        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
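
The practical effect of the `modeling_bart.py` rewrite is easiest to see in `wrap_model_if_needed`: options move from an untyped dict with inline defaults to attributes on the typed config, which now owns its defaults. Roughly (the "before" is condensed from the removed lines above):

    # 0.7.3: untyped dict, defaults scattered at each call site
    enc_max_seq_len = rbln_config.model_cfg.get("enc_max_seq_len", 1024)
    use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)

    # 0.7.4: attribute access on RBLNBartForConditionalGenerationConfig
    enc_max_seq_len = rbln_config.enc_max_seq_len
    use_attention_mask = rbln_config.use_attention_mask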

optimum/rbln/transformers/models/bert/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_bert import RBLNBertForMaskedLMConfig, RBLNBertForQuestionAnsweringConfig, RBLNBertModelConfig
 from .modeling_bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel

optimum/rbln/transformers/models/bert/configuration_bert.py (new file)
@@ -0,0 +1,31 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import (
+    RBLNModelForMaskedLMConfig,
+    RBLNModelForQuestionAnsweringConfig,
+    RBLNTransformerEncoderForFeatureExtractionConfig,
+)
+
+
+class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
+    pass
+
+
+class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
+    pass

optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -12,92 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
-
-from transformers import PretrainedConfig
-
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForQuestionAnswering
+from ...modeling_generic import (
+    RBLNModelForMaskedLM,
+    RBLNModelForQuestionAnswering,
+    RBLNTransformerEncoderForFeatureExtraction,
+)
 
 
 logger = get_logger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
-
-
-class RBLNBertModel(RBLNModel):
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-
-        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
 
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
+class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
+    rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
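
With the duplicated `_get_rbln_config` gone, `RBLNBertModel` pins its compiled inputs up front instead of probing the tokenizer at export time. Judging from the removed body, the generic base class should build the same static input declaration, now for exactly these two tensors (shape/dtype pattern copied from the deleted code; the values are illustrative):

    batch_size, max_seq_len = 1, 512
    input_info = [
        ("input_ids", [batch_size, max_seq_len], "int64"),
        ("attention_mask", [batch_size, max_seq_len], "int64"),
    ]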

optimum/rbln/transformers/models/clip/__init__.py
@@ -12,6 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_clip import (
+    RBLNCLIPTextModelConfig,
+    RBLNCLIPTextModelWithProjectionConfig,
+    RBLNCLIPVisionModelConfig,
+    RBLNCLIPVisionModelWithProjectionConfig,
+)
 from .modeling_clip import (
     RBLNCLIPTextModel,
     RBLNCLIPTextModelWithProjection,

optimum/rbln/transformers/models/clip/configuration_clip.py (new file)
@@ -0,0 +1,79 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNCLIPTextModelConfig(RBLNModelConfig):
+    def __init__(self, batch_size: Optional[int] = None, **kwargs):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
+    pass
+
+
+class RBLNCLIPVisionModelConfig(RBLNModelConfig):
+    def __init__(self, batch_size: Optional[int] = None, image_size: Optional[int] = None, **kwargs):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+            image_size (Optional[int]): The size of input images. Can be an integer for square images,
+                a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.image_size = image_size
+
+    @property
+    def image_width(self):
+        if isinstance(self.image_size, int):
+            return self.image_size
+        elif isinstance(self.image_size, (list, tuple)):
+            return self.image_size[1]
+        else:
+            return self.image_size["width"]
+
+    @property
+    def image_height(self):
+        if isinstance(self.image_size, int):
+            return self.image_size
+        elif isinstance(self.image_size, (list, tuple)):
+            return self.image_size[0]
+        else:
+            return self.image_size["height"]
+
+
+class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
+    pass
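
`configuration_clip.py` is shown in full, so its `image_size` normalization can be exercised directly: an int means a square image, a sequence is interpreted as (height, width), and a dict uses explicit keys. A minimal sketch, importing from the module path added in this diff:

    from optimum.rbln.transformers.models.clip.configuration_clip import (
        RBLNCLIPVisionModelConfig,
    )

    # int -> square, tuple/list -> (height, width), dict -> explicit keys
    for size in (224, (336, 448), {"height": 512, "width": 768}):
        cfg = RBLNCLIPVisionModelConfig(image_size=size)
        print(cfg.image_height, cfg.image_width)
    # -> 224 224 / 336 448 / 512 768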