optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. optimum/rbln/__init__.py +156 -36
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/configuration_utils.py +772 -0
  4. optimum/rbln/diffusers/__init__.py +56 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +30 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
  13. optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
  19. optimum/rbln/diffusers/modeling_diffusers.py +63 -122
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
  21. optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
  23. optimum/rbln/diffusers/models/controlnet.py +55 -70
  24. optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
  25. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
  26. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
  27. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
  28. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  29. optimum/rbln/modeling.py +58 -39
  30. optimum/rbln/modeling_base.py +85 -75
  31. optimum/rbln/transformers/__init__.py +79 -8
  32. optimum/rbln/transformers/configuration_alias.py +49 -0
  33. optimum/rbln/transformers/configuration_generic.py +142 -0
  34. optimum/rbln/transformers/modeling_generic.py +193 -280
  35. optimum/rbln/transformers/models/__init__.py +96 -34
  36. optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
  37. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  38. optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
  39. optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
  40. optimum/rbln/transformers/models/bert/__init__.py +1 -0
  41. optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
  43. optimum/rbln/transformers/models/clip/__init__.py +6 -0
  44. optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
  45. optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
  46. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  47. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
  48. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +50 -43
  49. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +114 -141
  50. optimum/rbln/transformers/models/dpt/__init__.py +1 -0
  51. optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
  52. optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
  53. optimum/rbln/transformers/models/exaone/__init__.py +1 -0
  54. optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
  55. optimum/rbln/transformers/models/gemma/__init__.py +1 -0
  56. optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
  57. optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
  58. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
  59. optimum/rbln/transformers/models/llama/__init__.py +1 -0
  60. optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
  61. optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
  62. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
  63. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
  64. optimum/rbln/transformers/models/midm/__init__.py +1 -0
  65. optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
  66. optimum/rbln/transformers/models/mistral/__init__.py +1 -0
  67. optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
  68. optimum/rbln/transformers/models/phi/__init__.py +1 -0
  69. optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
  70. optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
  71. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
  72. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
  73. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
  74. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
  75. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  76. optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
  77. optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
  78. optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
  79. optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
  80. optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
  81. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
  82. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
  83. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
  84. optimum/rbln/transformers/models/whisper/__init__.py +1 -0
  85. optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
  86. optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
  87. optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
  88. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
  89. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
  90. optimum/rbln/utils/submodule.py +26 -43
  91. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/METADATA +1 -1
  92. optimum_rbln-0.7.4a5.dist-info/RECORD +162 -0
  93. optimum/rbln/modeling_config.py +0 -310
  94. optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
  95. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/WHEEL +0 -0
  96. {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a5.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/__init__.py

@@ -32,29 +32,60 @@ _import_structure = {
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
     ],
-    "bart": ["RBLNBartForConditionalGeneration", "RBLNBartModel"],
-    "bert": ["RBLNBertModel", "RBLNBertForQuestionAnswering", "RBLNBertForMaskedLM"],
+    "bart": [
+        "RBLNBartForConditionalGeneration",
+        "RBLNBartModel",
+        "RBLNBartForConditionalGenerationConfig",
+        "RBLNBartModelConfig",
+    ],
+    "bert": [
+        "RBLNBertModel",
+        "RBLNBertModelConfig",
+        "RBLNBertForQuestionAnswering",
+        "RBLNBertForQuestionAnsweringConfig",
+        "RBLNBertForMaskedLM",
+        "RBLNBertForMaskedLMConfig",
+    ],
     "clip": [
         "RBLNCLIPTextModel",
+        "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
+        "RBLNCLIPTextModelWithProjectionConfig",
         "RBLNCLIPVisionModel",
+        "RBLNCLIPVisionModelConfig",
         "RBLNCLIPVisionModelWithProjection",
+        "RBLNCLIPVisionModelWithProjectionConfig",
+    ],
+    "decoderonly": [
+        "RBLNDecoderOnlyModelForCausalLM",
+        "RBLNDecoderOnlyModelForCausalLMConfig",
+    ],
+    "dpt": [
+        "RBLNDPTForDepthEstimation",
+        "RBLNDPTForDepthEstimationConfig",
+    ],
+    "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
+    "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig"],
+    "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig"],
+    "llama": ["RBLNLlamaForCausalLM", "RBLNLlamaForCausalLMConfig"],
+    "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
+    "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
+    "mistral": ["RBLNMistralForCausalLM", "RBLNMistralForCausalLMConfig"],
+    "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig"],
+    "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig"],
+    "time_series_transformers": [
+        "RBLNTimeSeriesTransformerForPrediction",
+        "RBLNTimeSeriesTransformerForPredictionConfig",
+    ],
+    "t5": [
+        "RBLNT5EncoderModel",
+        "RBLNT5ForConditionalGeneration",
+        "RBLNT5EncoderModelConfig",
+        "RBLNT5ForConditionalGenerationConfig",
     ],
-    "dpt": ["RBLNDPTForDepthEstimation"],
-    "exaone": ["RBLNExaoneForCausalLM"],
-    "gemma": ["RBLNGemmaForCausalLM"],
-    "gpt2": ["RBLNGPT2LMHeadModel"],
-    "llama": ["RBLNLlamaForCausalLM"],
-    "llava_next": ["RBLNLlavaNextForConditionalGeneration"],
-    "midm": ["RBLNMidmLMHeadModel"],
-    "mistral": ["RBLNMistralForCausalLM"],
-    "phi": ["RBLNPhiForCausalLM"],
-    "qwen2": ["RBLNQwen2ForCausalLM"],
-    "time_series_transformers": ["RBLNTimeSeriesTransformerForPrediction"],
-    "t5": ["RBLNT5EncoderModel", "RBLNT5ForConditionalGeneration"],
-    "wav2vec2": ["RBLNWav2Vec2ForCTC"],
-    "whisper": ["RBLNWhisperForConditionalGeneration"],
-    "xlm_roberta": ["RBLNXLMRobertaModel"],
+    "wav2vec2": ["RBLNWav2Vec2ForCTC", "RBLNWav2Vec2ForCTCConfig"],
+    "whisper": ["RBLNWhisperForConditionalGeneration", "RBLNWhisperForConditionalGenerationConfig"],
+    "xlm_roberta": ["RBLNXLMRobertaModel", "RBLNXLMRobertaModelConfig"],
 }
 
 if TYPE_CHECKING:
@@ -72,29 +103,60 @@ if TYPE_CHECKING:
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForVision2Seq,
     )
-    from .bart import RBLNBartForConditionalGeneration, RBLNBartModel
-    from .bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
+    from .bart import (
+        RBLNBartForConditionalGeneration,
+        RBLNBartForConditionalGenerationConfig,
+        RBLNBartModel,
+        RBLNBartModelConfig,
+    )
+    from .bert import (
+        RBLNBertForMaskedLM,
+        RBLNBertForMaskedLMConfig,
+        RBLNBertForQuestionAnswering,
+        RBLNBertForQuestionAnsweringConfig,
+        RBLNBertModel,
+        RBLNBertModelConfig,
+    )
     from .clip import (
         RBLNCLIPTextModel,
+        RBLNCLIPTextModelConfig,
         RBLNCLIPTextModelWithProjection,
+        RBLNCLIPTextModelWithProjectionConfig,
         RBLNCLIPVisionModel,
+        RBLNCLIPVisionModelConfig,
         RBLNCLIPVisionModelWithProjection,
+        RBLNCLIPVisionModelWithProjectionConfig,
+    )
+    from .decoderonly import (
+        RBLNDecoderOnlyModelForCausalLM,
+        RBLNDecoderOnlyModelForCausalLMConfig,
+    )
+    from .dpt import (
+        RBLNDPTForDepthEstimation,
+        RBLNDPTForDepthEstimationConfig,
+    )
+    from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
+    from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig
+    from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig
+    from .llama import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+    from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
+    from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
+    from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
+    from .phi import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig
+    from .qwen2 import RBLNQwen2ForCausalLM, RBLNQwen2ForCausalLMConfig
+    from .t5 import (
+        RBLNT5EncoderModel,
+        RBLNT5EncoderModelConfig,
+        RBLNT5ForConditionalGeneration,
+        RBLNT5ForConditionalGenerationConfig,
+    )
+    from .time_series_transformers import (
+        RBLNTimeSeriesTransformerForPrediction,
+        RBLNTimeSeriesTransformerForPredictionConfig,
     )
-    from .dpt import RBLNDPTForDepthEstimation
-    from .exaone import RBLNExaoneForCausalLM
-    from .gemma import RBLNGemmaForCausalLM
-    from .gpt2 import RBLNGPT2LMHeadModel
-    from .llama import RBLNLlamaForCausalLM
-    from .llava_next import RBLNLlavaNextForConditionalGeneration
-    from .midm import RBLNMidmLMHeadModel
-    from .mistral import RBLNMistralForCausalLM
-    from .phi import RBLNPhiForCausalLM
-    from .qwen2 import RBLNQwen2ForCausalLM
-    from .t5 import RBLNT5EncoderModel, RBLNT5ForConditionalGeneration
-    from .time_series_transformers import RBLNTimeSeriesTransformerForPrediction
-    from .wav2vec2 import RBLNWav2Vec2ForCTC
-    from .whisper import RBLNWhisperForConditionalGeneration
-    from .xlm_roberta import RBLNXLMRobertaModel
+    from .wav2vec2 import RBLNWav2Vec2ForCTC, RBLNWav2Vec2ForCTCConfig
+    from .whisper import RBLNWhisperForConditionalGeneration, RBLNWhisperForConditionalGenerationConfig
+    from .xlm_roberta import RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig
 
 else:
     import sys
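
The pattern across both hunks is uniform: every exported model class gains a same-named `*Config` companion. A minimal sketch of the pairing (the import path follows the `_import_structure` above; whether the top-level `optimum.rbln` namespace also re-exports these pairs is an assumption, though the +156-line change to `optimum/rbln/__init__.py` in the file list suggests it does):

```python
# Minimal sketch: a model class and its new companion config class are
# exported side by side by optimum/rbln/transformers/models/__init__.py.
from optimum.rbln.transformers.models import (
    RBLNLlamaForCausalLM,
    RBLNLlamaForCausalLMConfig,  # new in 0.7.4a5
)
```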
optimum/rbln/transformers/models/auto/auto_factory.py

@@ -20,8 +20,8 @@ from transformers import AutoConfig, PretrainedConfig
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.auto_factory import _get_model_class
 
+from optimum.rbln.configuration_utils import RBLNAutoConfig
 from optimum.rbln.modeling_base import RBLNBaseModel
-from optimum.rbln.modeling_config import RBLNConfig
 from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
 
 
@@ -154,9 +154,9 @@ class _BaseAutoModelClass:
         model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
             model_id=pretrained_model_name_or_path, **filtered_kwargs
         )
-        rbln_config = RBLNConfig.load(model_path_subfolder)
+        rbln_config = RBLNAutoConfig.load(model_path_subfolder)
 
-        return rbln_config.meta["cls"]
+        return rbln_config.rbln_model_cls_name
 
     @classmethod
     def from_pretrained(
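
The second hunk swaps the class-name lookup from the old `RBLNConfig.meta["cls"]` dict field to a typed `rbln_model_cls_name` attribute on `RBLNAutoConfig`. A hedged sketch of the new path (only `RBLNAutoConfig.load()` and `rbln_model_cls_name` appear in the hunk; the directory argument and the printed value are illustrative):

```python
from optimum.rbln.configuration_utils import RBLNAutoConfig

# In the real code path, _load_compiled_model_dir resolves this directory;
# a plain local path stands in for it here.
rbln_config = RBLNAutoConfig.load("path/to/compiled_model")
print(rbln_config.rbln_model_cls_name)  # e.g. "RBLNLlamaForCausalLM"
```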
optimum/rbln/transformers/models/bart/__init__.py

@@ -13,4 +13,5 @@
 # limitations under the License.
 
 from ....ops import paged_attn_decode, paged_causal_attn_decode
+from .configuration_bart import RBLNBartForConditionalGenerationConfig, RBLNBartModelConfig
 from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
optimum/rbln/transformers/models/bart/configuration_bart.py

@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+    pass
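
Both BART configs are thin subclasses that only pin the model family; the actual fields live on the generic bases. Assuming the seq2seq base accepts the keyword arguments consumed by `modeling_bart.py` below (its definition sits in `configuration_seq2seq2.py`, which this diff lists but does not show), construction would look like:

```python
from optimum.rbln.transformers.models.bart import RBLNBartForConditionalGenerationConfig

# enc_max_seq_len and use_attention_mask are read back as attributes in
# wrap_model_if_needed() below; their acceptance here as keyword arguments
# is an assumption about the RBLNModelForSeq2SeqLMConfig base.
cfg = RBLNBartForConditionalGenerationConfig(
    enc_max_seq_len=1024,
    use_attention_mask=False,
)
```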
optimum/rbln/transformers/models/bart/modeling_bart.py

@@ -13,110 +13,36 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable
 
-from transformers import BartForConditionalGeneration, PretrainedConfig, PreTrainedModel
+from transformers import BartForConditionalGeneration, PreTrainedModel
 
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 from .bart_architecture import BartWrapper
+from .configuration_bart import RBLNBartForConditionalGenerationConfig
 
 
 logger = get_logger()
 
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+    from transformers import PreTrainedModel
 
 
-class RBLNBartModel(RBLNModel):
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-
-        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-                    # BartModel's forward() does not take token_type_ids as input.
-                    # (Added because some of the tokenizers includes 'token_type_ids')
-                    if "token_type_ids" in rbln_model_input_names:
-                        rbln_model_input_names.remove("token_type_ids")
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as BartModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
+class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
+    pass
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        enc_max_seq_len = (
-            rbln_config.model_cfg["enc_max_seq_len"] if "enc_max_seq_len" in rbln_config.model_cfg else 1024
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNBartForConditionalGenerationConfig):
+        return BartWrapper(
+            model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
-        use_attention_mask = rbln_config.model_cfg.get("use_attention_mask", False)
-
-        return BartWrapper(model, enc_max_seq_len=enc_max_seq_len, use_attention_mask=use_attention_mask)
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
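
The mechanical change in `wrap_model_if_needed` is representative of the whole release: untyped `model_cfg` dict lookups with call-site defaults become typed attribute reads on a per-model config object. A runnable stand-in comparison (using `SimpleNamespace` so neither real config class needs to be constructed):

```python
from types import SimpleNamespace

# 0.7.4a4 style: values in an untyped model_cfg dict, defaults at the call site
old_cfg = SimpleNamespace(model_cfg={"enc_max_seq_len": 512})
enc_max_seq_len = old_cfg.model_cfg.get("enc_max_seq_len", 1024)
use_attention_mask = old_cfg.model_cfg.get("use_attention_mask", False)

# 0.7.4a5 style: typed attributes; defaults presumably live in the config class
new_cfg = SimpleNamespace(enc_max_seq_len=512, use_attention_mask=False)
enc_max_seq_len = new_cfg.enc_max_seq_len
use_attention_mask = new_cfg.use_attention_mask
```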
optimum/rbln/transformers/models/bert/__init__.py

@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_bert import RBLNBertForMaskedLMConfig, RBLNBertForQuestionAnsweringConfig, RBLNBertModelConfig
 from .modeling_bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
optimum/rbln/transformers/models/bert/configuration_bert.py

@@ -0,0 +1,31 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import (
+    RBLNModelForMaskedLMConfig,
+    RBLNModelForQuestionAnsweringConfig,
+    RBLNTransformerEncoderForFeatureExtractionConfig,
+)
+
+
+class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    pass
+
+
+class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
+    pass
+
+
+class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
+    pass
optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -12,92 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
-
-from transformers import PretrainedConfig
-
-from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
-from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForQuestionAnswering
+from ...modeling_generic import (
+    RBLNModelForMaskedLM,
+    RBLNModelForQuestionAnswering,
+    RBLNTransformerEncoderForFeatureExtraction,
+)
 
 
 logger = get_logger(__name__)
 
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
-
-
-class RBLNBertModel(RBLNModel):
-    @classmethod
-    def _get_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_kwargs: Dict[str, Any] = {},
-    ) -> RBLNConfig:
-        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
-        rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
-
-        max_position_embeddings = getattr(model_config, "max_position_embeddings", None)
-
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = max_position_embeddings
-            if rbln_max_seq_len is None:
-                for tokenizer in preprocessors:
-                    if hasattr(tokenizer, "model_max_length"):
-                        rbln_max_seq_len = tokenizer.model_max_length
-                        break
-                if rbln_max_seq_len is None:
-                    raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
-            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
-
-        signature_params = inspect.signature(cls.get_hf_class().forward).parameters.keys()
-
-        if rbln_model_input_names is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_input_names"):
-                    rbln_model_input_names = [name for name in signature_params if name in tokenizer.model_input_names]
-
-                    invalid_params = set(rbln_model_input_names) - set(signature_params)
-                    if invalid_params:
-                        raise ValueError(f"Invalid model input names: {invalid_params}")
-                    break
-            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
-                rbln_model_input_names = cls.rbln_model_input_names
-            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
-                raise ValueError(
-                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
-                    f"and be sure to make the order of the inputs same as BertModel forward() arguments like ({list(signature_params)})"
-                )
-        else:
-            invalid_params = set(rbln_model_input_names) - set(signature_params)
-            if invalid_params:
-                raise ValueError(f"Invalid model input names: {invalid_params}")
-            rbln_model_input_names = [name for name in signature_params if name in rbln_model_input_names]
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
-
-        input_info = [
-            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
-        ]
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
-
-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[rbln_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
 
-        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
-        return rbln_config
+class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
+    rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
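
The ~70-line tokenizer-introspection method deleted here is not lost; it moves behind `RBLNTransformerEncoderForFeatureExtraction` in the reworked `modeling_generic.py` (+193/-280 in the file list, not shown in this diff). What remains on the subclass is a single declarative hook. A hedged sketch of that extension point (`RBLNMyEncoderModel` is a hypothetical name; the base class is assumed to derive `input_info` and `max_seq_len` the way the deleted method did):

```python
from optimum.rbln.transformers.modeling_generic import (
    RBLNTransformerEncoderForFeatureExtraction,
)


class RBLNMyEncoderModel(RBLNTransformerEncoderForFeatureExtraction):  # hypothetical
    # Same hook RBLNBertModel pins above; consumed by the generic base.
    rbln_model_input_names = ["input_ids", "attention_mask"]
```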
optimum/rbln/transformers/models/clip/__init__.py

@@ -12,6 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .configuration_clip import (
+    RBLNCLIPTextModelConfig,
+    RBLNCLIPTextModelWithProjectionConfig,
+    RBLNCLIPVisionModelConfig,
+    RBLNCLIPVisionModelWithProjectionConfig,
+)
 from .modeling_clip import (
     RBLNCLIPTextModel,
     RBLNCLIPTextModelWithProjection,
optimum/rbln/transformers/models/clip/configuration_clip.py

@@ -0,0 +1,79 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from ....configuration_utils import RBLNModelConfig
+
+
+class RBLNCLIPTextModelConfig(RBLNModelConfig):
+    def __init__(self, batch_size: Optional[int] = None, **kwargs):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
+    pass
+
+
+class RBLNCLIPVisionModelConfig(RBLNModelConfig):
+    def __init__(self, batch_size: Optional[int] = None, image_size: Optional[int] = None, **kwargs):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+            image_size (Optional[int]): The size of input images. Can be an integer for square images,
+                a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.image_size = image_size
+
+    @property
+    def image_width(self):
+        if isinstance(self.image_size, int):
+            return self.image_size
+        elif isinstance(self.image_size, (list, tuple)):
+            return self.image_size[1]
+        else:
+            return self.image_size["width"]
+
+    @property
+    def image_height(self):
+        if isinstance(self.image_size, int):
+            return self.image_size
+        elif isinstance(self.image_size, (list, tuple)):
+            return self.image_size[0]
+        else:
+            return self.image_size["height"]
+
+
+class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
+    pass
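
`RBLNCLIPVisionModelConfig.image_size` accepts three shapes of input, normalized on read by the `image_height`/`image_width` properties defined above. A usage sketch (assuming `RBLNModelConfig` needs no additional required constructor arguments):

```python
from optimum.rbln.transformers.models.clip import RBLNCLIPVisionModelConfig

cfg = RBLNCLIPVisionModelConfig(image_size=224)  # int -> square image
assert (cfg.image_height, cfg.image_width) == (224, 224)

cfg = RBLNCLIPVisionModelConfig(image_size=(336, 224))  # tuple -> (height, width)
assert (cfg.image_height, cfg.image_width) == (336, 224)

cfg = RBLNCLIPVisionModelConfig(image_size={"height": 336, "width": 224})
assert (cfg.image_height, cfg.image_width) == (336, 224)
```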