optimum-rbln 0.8.3a4__py3-none-any.whl → 0.8.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/modeling.py +1 -0
- optimum/rbln/modeling_base.py +24 -7
- optimum/rbln/transformers/__init__.py +14 -0
- optimum/rbln/transformers/configuration_generic.py +2 -0
- optimum/rbln/transformers/modeling_generic.py +12 -4
- optimum/rbln/transformers/models/__init__.py +18 -0
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +6 -1
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +1 -1
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +6 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +86 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +507 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +32 -7
- optimum/rbln/utils/submodule.py +10 -4
- {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.3rc0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.3rc0.dist-info}/RECORD +25 -20
- {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.3a4.dist-info → optimum_rbln-0.8.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py
CHANGED
@@ -47,6 +47,7 @@ _import_structure = {
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
         "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
         "RBLNBartForConditionalGeneration",
         "RBLNBartForConditionalGenerationConfig",
         "RBLNBartModel",
@@ -97,6 +98,12 @@ _import_structure = {
         "RBLNGPT2ModelConfig",
         "RBLNGPT2LMHeadModel",
         "RBLNGPT2LMHeadModelConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
         "RBLNIdefics3VisionTransformer",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
@@ -326,6 +333,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
         RBLNBartForConditionalGeneration,
         RBLNBartForConditionalGenerationConfig,
         RBLNBartModel,
@@ -376,6 +384,12 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
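Both halves of each hunk above extend the package's lazy-import table and its TYPE_CHECKING mirror. A condensed sketch of that convention, using transformers' _LazyModule helper (the surrounding boilerplate is abbreviated here, not copied from this release):

import sys
from typing import TYPE_CHECKING

from transformers.utils import _LazyModule

_import_structure = {
    "transformers": ["RBLNAutoModelForZeroShotObjectDetection"],
}

if TYPE_CHECKING:
    # Static analyzers and type checkers see real imports...
    from .transformers import RBLNAutoModelForZeroShotObjectDetection
else:
    # ...while at runtime the module resolves names lazily on first access.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
    )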
optimum/rbln/__version__.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.8.3a4'
-__version_tuple__ = version_tuple = (0, 8, 3, 'a4')
+__version__ = version = '0.8.3rc0'
+__version_tuple__ = version_tuple = (0, 8, 3, 'rc0')
 
 __commit_id__ = commit_id = None
optimum/rbln/modeling.py
CHANGED
optimum/rbln/modeling_base.py
CHANGED
@@ -525,13 +525,30 @@ class RBLNBaseModel(SubModulesMixin, PushToHubMixin, PreTrainedModel):
 
         # If everything succeeded, move files to target directory
         if os.path.exists(save_directory_path):
-            #
-
-
-
-
-
-
+            # Merge files from tmp_dir into existing directory
+            def _merge_dir(src_root: str, dst_root: str):
+                for name in os.listdir(src_root):
+                    src_item = os.path.join(src_root, name)
+                    dst_item = os.path.join(dst_root, name)
+
+                    if os.path.islink(src_item) or os.path.isfile(src_item):
+                        os.makedirs(os.path.dirname(dst_item), exist_ok=True)
+                        if os.path.isdir(dst_item) and not os.path.islink(dst_item):
+                            shutil.rmtree(dst_item)
+                        os.replace(src_item, dst_item)
+                    elif os.path.isdir(src_item):
+                        if os.path.islink(dst_item) or os.path.isfile(dst_item):
+                            os.remove(dst_item)
+                        os.makedirs(dst_item, exist_ok=True)
+                        _merge_dir(src_item, dst_item)
+                    else:
+                        # Fallback for special file types
+                        os.replace(src_item, dst_item)
+
+            _merge_dir(tmp_dir, str(save_directory_path))
+
+            # Remove the temporary directory tree after merge
+            shutil.rmtree(tmp_dir)
         else:
             # If target doesn't exist, just rename tmp_dir to target
             os.rename(tmp_dir, save_directory_path)
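The new _merge_dir helper changes save_pretrained's behavior when the target directory already exists: instead of replacing the whole tree, the finished tmp_dir is merged into it. A standalone sketch of those semantics (paths and file names are illustrative, and os.replace assumes both directories live on the same filesystem):

import os
import shutil
import tempfile

def merge_dir(src_root: str, dst_root: str):
    # Mirrors the diff: files overwrite conflicting entries, directories merge recursively.
    for name in os.listdir(src_root):
        src_item = os.path.join(src_root, name)
        dst_item = os.path.join(dst_root, name)
        if os.path.islink(src_item) or os.path.isfile(src_item):
            if os.path.isdir(dst_item) and not os.path.islink(dst_item):
                shutil.rmtree(dst_item)  # a file replaces a same-named directory
            os.replace(src_item, dst_item)
        elif os.path.isdir(src_item):
            if os.path.islink(dst_item) or os.path.isfile(dst_item):
                os.remove(dst_item)  # a directory replaces a same-named file
            os.makedirs(dst_item, exist_ok=True)
            merge_dir(src_item, dst_item)

src, dst = tempfile.mkdtemp(), tempfile.mkdtemp()
open(os.path.join(src, "model.rbln"), "w").close()
open(os.path.join(dst, "config.json"), "w").close()  # pre-existing, left untouched
merge_dir(src, dst)
shutil.rmtree(src)
print(sorted(os.listdir(dst)))  # ['config.json', 'model.rbln']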
optimum/rbln/transformers/__init__.py
CHANGED
@@ -35,6 +35,7 @@ _import_structure = {
         "RBLNAutoModelForSpeechSeq2Seq",
         "RBLNAutoModelForVision2Seq",
         "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
         "RBLNBartForConditionalGeneration",
         "RBLNBartForConditionalGenerationConfig",
         "RBLNBartModel",
@@ -85,6 +86,12 @@ _import_structure = {
         "RBLNGPT2LMHeadModelConfig",
         "RBLNGPT2Model",
         "RBLNGPT2ModelConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
         "RBLNIdefics3ForConditionalGeneration",
         "RBLNIdefics3ForConditionalGenerationConfig",
         "RBLNIdefics3VisionTransformer",
@@ -178,6 +185,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
         RBLNBartForConditionalGeneration,
         RBLNBartForConditionalGenerationConfig,
         RBLNBartModel,
@@ -228,6 +236,12 @@ if TYPE_CHECKING:
         RBLNGPT2LMHeadModelConfig,
         RBLNGPT2Model,
         RBLNGPT2ModelConfig,
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
         RBLNIdefics3VisionTransformer,
optimum/rbln/transformers/configuration_generic.py
CHANGED
@@ -25,6 +25,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
         max_seq_len: Optional[int] = None,
         batch_size: Optional[int] = None,
         model_input_names: Optional[List[str]] = None,
+        model_input_shapes: Optional[List[Tuple[int, int]]] = None,
         **kwargs: Any,
     ):
         """
@@ -45,6 +46,7 @@ class RBLNTransformerEncoderConfig(RBLNModelConfig):
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
         self.model_input_names = model_input_names or self.rbln_model_input_names
+        self.model_input_shapes = model_input_shapes
 
 
 class RBLNImageModelConfig(RBLNModelConfig):
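model_input_shapes gives each compiled input its own static shape instead of the shared [batch_size, max_seq_len]. A hypothetical construction, assuming a concrete subclass such as RBLNBertModelConfig forwards these keyword arguments unchanged:

from optimum.rbln import RBLNBertModelConfig

config = RBLNBertModelConfig(
    model_input_names=["input_ids", "attention_mask", "token_type_ids"],
    # One (batch, seq_len) pair per input name, paired positionally.
    model_input_shapes=[(1, 384), (1, 384), (1, 384)],
)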
optimum/rbln/transformers/modeling_generic.py
CHANGED
@@ -127,10 +127,18 @@ class RBLNTransformerEncoder(RBLNModel):
                 "This is an internal error. Please report it to the developers."
             )
 
-
-
-
-
+        if rbln_config.model_input_shapes is None:
+            input_info = [
+                (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
+                for model_input_name in rbln_config.model_input_names
+            ]
+        else:
+            input_info = [
+                (model_input_name, model_input_shape, cls.rbln_dtype)
+                for model_input_name, model_input_shape in zip(
+                    rbln_config.model_input_names, rbln_config.model_input_shapes
+                )
+            ]
 
         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
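A plain-data sketch of the two branches: without model_input_shapes every input gets the shared [batch_size, max_seq_len] shape; with it, names and shapes are paired positionally (values below are illustrative):

names = ["input_ids", "attention_mask"]
batch_size, max_seq_len, dtype = 1, 512, "int64"

# Branch 1: shared shape for every input.
shared = [(name, [batch_size, max_seq_len], dtype) for name in names]
# [('input_ids', [1, 512], 'int64'), ('attention_mask', [1, 512], 'int64')]

# Branch 2: per-input shapes, zipped with the names.
per_input = [(name, shape, dtype) for name, shape in zip(names, [(1, 512), (1, 128)])]
# [('input_ids', (1, 512), 'int64'), ('attention_mask', (1, 128), 'int64')]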
optimum/rbln/transformers/models/__init__.py
CHANGED
@@ -37,6 +37,7 @@ _import_structure = {
         "RBLNAutoModelForVision2Seq",
         "RBLNAutoModelForImageTextToText",
         "RBLNAutoModelForTextEncoding",
+        "RBLNAutoModelForZeroShotObjectDetection",
     ],
     "bart": [
         "RBLNBartForConditionalGeneration",
@@ -165,6 +166,14 @@ _import_structure = {
         "RBLNXLMRobertaForSequenceClassification",
         "RBLNXLMRobertaForSequenceClassificationConfig",
     ],
+    "grounding_dino": [
+        "RBLNGroundingDinoForObjectDetection",
+        "RBLNGroundingDinoForObjectDetectionConfig",
+        "RBLNGroundingDinoEncoder",
+        "RBLNGroundingDinoEncoderConfig",
+        "RBLNGroundingDinoDecoder",
+        "RBLNGroundingDinoDecoderConfig",
+    ],
 }
 
 if TYPE_CHECKING:
@@ -184,6 +193,7 @@ if TYPE_CHECKING:
         RBLNAutoModelForSpeechSeq2Seq,
         RBLNAutoModelForTextEncoding,
         RBLNAutoModelForVision2Seq,
+        RBLNAutoModelForZeroShotObjectDetection,
     )
     from .bart import (
         RBLNBartForConditionalGeneration,
@@ -236,6 +246,14 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
     )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+    from .grounding_dino import (
+        RBLNGroundingDinoDecoder,
+        RBLNGroundingDinoDecoderConfig,
+        RBLNGroundingDinoEncoder,
+        RBLNGroundingDinoEncoderConfig,
+        RBLNGroundingDinoForObjectDetection,
+        RBLNGroundingDinoForObjectDetectionConfig,
+    )
     from .idefics3 import (
         RBLNIdefics3ForConditionalGeneration,
         RBLNIdefics3ForConditionalGenerationConfig,
optimum/rbln/transformers/models/auto/modeling_auto.py
CHANGED
@@ -39,6 +39,8 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
     MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
@@ -122,3 +124,8 @@ class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
 class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
     _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
+
+
+class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
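With the mapping wired up, the new auto class can resolve zero-shot object detectors by architecture. A hedged usage sketch, assuming the usual optimum-rbln export flow (the checkpoint name is an example, not taken from this diff):

from optimum.rbln import RBLNAutoModelForZeroShotObjectDetection

model = RBLNAutoModelForZeroShotObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny",  # example checkpoint
    export=True,  # compile from the original PyTorch weights for RBLN devices
)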
optimum/rbln/transformers/models/bert/bert_architecture.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
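The tuple branch most likely exists because traced outputs cannot carry None entries, so the wrapper drops them. A quick demonstration reusing the BertModelWrapper defined above with a stand-in model:

import torch

class TupleOutputModel(torch.nn.Module):
    # Stand-in for a BERT returning (last_hidden_state, pooler_output=None)
    def forward(self, x):
        return (x * 2, None)

wrapped = BertModelWrapper(TupleOutputModel(), rbln_config=None)
print(wrapped(torch.ones(1, 4)))  # (tensor([[2., 2., 2., 2.]]),); the None is gone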
optimum/rbln/transformers/models/bert/modeling_bert.py
CHANGED
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import torch
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +34,10 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
CHANGED
@@ -174,7 +174,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()
 
     @classmethod
-    def _update_submodule_config(
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens
 
optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py
CHANGED
@@ -20,6 +20,6 @@ class RBLNDepthAnythingForDepthEstimation(RBLNModelForDepthEstimation):
     """
     RBLN optimized DepthAnythingForDepthEstimation model for depth estimation tasks.
 
-    This class provides hardware-accelerated inference for Depth Anything V2
+    This class provides hardware-accelerated inference for Depth Anything V2
     models on RBLN devices, providing the most capable monocular depth estimation (MDE) model.
     """
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
CHANGED
@@ -403,7 +403,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return rbln_config
 
     @classmethod
-    def _update_submodule_config(
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.image_prefill_chunk_size is None:
             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
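Both this hunk and the Blip-2 one above widen the same classmethod hook so submodule configs can consult the loaded model's config and, now, the preprocessors. A schematic of the pattern; the field name is illustrative:

@classmethod
def _update_submodule_config(cls, model, rbln_config, preprocessors):
    # Backfill a compile-time field from the checkpoint config when unset.
    if rbln_config.some_field is None:
        rbln_config.some_field = model.config.some_field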
optimum/rbln/transformers/models/grounding_dino/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from .configuration_grounding_dino import (
+    RBLNGroundingDinoDecoderConfig,
+    RBLNGroundingDinoEncoderConfig,
+    RBLNGroundingDinoForObjectDetectionConfig,
+)
+from .modeling_grounding_dino import (
+    RBLNGroundingDinoDecoder,
+    RBLNGroundingDinoEncoder,
+    RBLNGroundingDinoForObjectDetection,
+)
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py
ADDED
@@ -0,0 +1,86 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+
+from ...configuration_generic import RBLNImageModelConfig, RBLNModelConfig
+
+
+class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
+    submodules = [
+        "text_backbone",
+        "backbone",
+        "encoder",
+        "decoder",
+    ]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        encoder: Optional["RBLNGroundingDinoEncoderConfig"] = None,
+        decoder: Optional["RBLNGroundingDinoDecoderConfig"] = None,
+        text_backbone: Optional["RBLNModelConfig"] = None,
+        backbone: Optional["RBLNModelConfig"] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.encoder = encoder
+        self.decoder = decoder
+        self.text_backbone = text_backbone
+        self.backbone = backbone
+        self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+class RBLNGroundingDinoComponentConfig(RBLNImageModelConfig):
+    def __init__(
+        self,
+        image_size: Optional[Union[int, Tuple[int, int]]] = None,
+        batch_size: Optional[int] = None,
+        spatial_shapes_list: Optional[List[Tuple[int, int]]] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs: Any,
+    ):
+        super().__init__(image_size=image_size, batch_size=batch_size, **kwargs)
+        self.spatial_shapes_list = spatial_shapes_list
+        self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
+
+    @property
+    def spatial_shapes(self):
+        if self.spatial_shapes_list is None:
+            raise ValueError("Spatial shapes are not defined. Please set them before accessing.")
+        return torch.tensor(self.spatial_shapes_list)
+
+
+class RBLNGroundingDinoEncoderConfig(RBLNGroundingDinoComponentConfig):
+    pass
+
+
+class RBLNGroundingDinoDecoderConfig(RBLNGroundingDinoComponentConfig):
+    pass