inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,807 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import copy
|
|
3
|
+
import math
|
|
4
|
+
from typing import Callable, List, Literal, Optional, Union
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
import torchvision
|
|
9
|
+
from pydantic import BaseModel, ConfigDict
|
|
10
|
+
from torch import Tensor, nn
|
|
11
|
+
|
|
12
|
+
from inference_models.models.rfdetr.backbone_builder import build_backbone
|
|
13
|
+
from inference_models.models.rfdetr.misc import NestedTensor
|
|
14
|
+
from inference_models.models.rfdetr.segmentation_head import SegmentationHead
|
|
15
|
+
from inference_models.models.rfdetr.transformer import build_transformer
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelConfig(BaseModel):
    """Architecture and runtime settings for an LW-DETR / RF-DETR model.

    Fields without a default value are required and must be supplied either by
    a subclass (see the RFDETR*Config presets) or by the caller. Arbitrary
    types are allowed because ``device`` is a plain ``torch.device``, which
    pydantic cannot validate natively.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"]
    out_feature_indexes: List[int]
    dec_layers: int
    two_stage: bool = True
    projector_scale: List[Literal["P3", "P4", "P5"]]
    hidden_dim: int
    patch_size: int
    num_windows: int
    sa_nheads: int
    ca_nheads: int
    dec_n_points: int
    bbox_reparam: bool = True
    lite_refpoint_refine: bool = True
    layer_norm: bool = True
    amp: bool = True
    num_classes: int = 90
    pretrain_weights: Union[str, None] = None
    device: torch.device
    resolution: int
    group_detr: int = 13
    gradient_checkpointing: bool = False
    positional_encoding_size: int
    ia_bce_loss: bool = True
    cls_loss_coef: float = 1.0
    segmentation_head: bool = False
    mask_downsample_ratio: int = 4
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class RFDETRBaseConfig(ModelConfig):
    """Preset for the RF-DETR "base" model (windowed DINOv2-small encoder).

    Provides defaults for every required ModelConfig field except ``device``,
    and adds the query-count settings ``num_queries`` and ``num_select``.
    """

    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_small"
    hidden_dim: int = 256
    patch_size: int = 14
    num_windows: int = 4
    dec_layers: int = 3
    sa_nheads: int = 8
    ca_nheads: int = 16
    dec_n_points: int = 2
    num_queries: int = 300
    num_select: int = 300
    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P4"]
    out_feature_indexes: List[int] = [2, 5, 8, 11]
    pretrain_weights: Union[str, None] = "rf-detr-base.pth"
    resolution: int = 560
    positional_encoding_size: int = 37
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class RFDETRLargeConfig(RFDETRBaseConfig):
    """Preset for the RF-DETR "large" model.

    Differs from the base preset by using the windowed DINOv2-base encoder, a
    wider hidden size with more attention heads / sampling points, and a
    two-level (P3 + P5) projector.
    """

    encoder: Literal["dinov2_windowed_small", "dinov2_windowed_base"] = "dinov2_windowed_base"
    hidden_dim: int = 384
    sa_nheads: int = 12
    ca_nheads: int = 24
    dec_n_points: int = 4
    projector_scale: List[Literal["P3", "P4", "P5"]] = ["P3", "P5"]
    pretrain_weights: Union[str, None] = "rf-detr-large.pth"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RFDETRNanoConfig(RFDETRBaseConfig):
    """Preset for the RF-DETR "nano" model: 384 px input, 16 px patches, 2 decoder layers."""

    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 2
    patch_size: int = 16
    resolution: int = 384
    positional_encoding_size: int = 24
    pretrain_weights: Union[str, None] = "rf-detr-nano.pth"
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class RFDETRSmallConfig(RFDETRBaseConfig):
    """Preset for the RF-DETR "small" model: 512 px input, 16 px patches, 3 decoder layers."""

    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 3
    patch_size: int = 16
    resolution: int = 512
    positional_encoding_size: int = 32
    pretrain_weights: Union[str, None] = "rf-detr-small.pth"
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class RFDETRMediumConfig(RFDETRBaseConfig):
    """Preset for the RF-DETR "medium" model: 576 px input, 16 px patches, 4 decoder layers."""

    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 4
    patch_size: int = 16
    resolution: int = 576
    positional_encoding_size: int = 36
    pretrain_weights: Union[str, None] = "rf-detr-medium.pth"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class RFDETRSegPreviewConfig(RFDETRBaseConfig):
    """Preset for the RF-DETR segmentation preview model.

    Enables the segmentation head, uses 432 px input with 12 px patches,
    4 decoder layers and 200 queries (with 200 selected outputs).
    """

    segmentation_head: bool = True
    out_feature_indexes: List[int] = [3, 6, 9, 12]
    num_windows: int = 2
    dec_layers: int = 4
    patch_size: int = 12
    resolution: int = 432
    positional_encoding_size: int = 36
    num_queries: int = 200
    num_select: int = 200
    pretrain_weights: Union[str, None] = "rf-detr-seg-preview.pt"
    num_classes: int = 90
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class LWDETR(nn.Module):
    """Group DETR v3 (LW-DETR) detector.

    Combines a backbone and a transformer with classification / box heads and
    an optional segmentation head. During training all ``group_detr`` query
    groups are used; in eval mode (and in export mode) only the first group of
    ``num_queries`` queries is fed to the transformer.
    """

    def __init__(
        self,
        backbone,
        transformer,
        segmentation_head,
        num_classes,
        num_queries,
        aux_loss=False,
        group_detr=1,
        two_stage=False,
        lite_refpoint_refine=False,
        bbox_reparam=False,
    ):
        """Initializes the model.

        Parameters:
            backbone: module producing ``(features, positional_encodings)`` when
                called with a NestedTensor (see ``forward``) and a 3-tuple
                ``(srcs, _, poss)`` when called with a plain tensor in export mode.
            transformer: module exposing ``d_model`` and ``decoder.bbox_embed``;
                called as ``transformer(srcs, masks, poss, refpoints, queries)``
                and returning ``(hs, ref_unsigmoid, hs_enc, ref_enc)``.
            segmentation_head: optional mask-prediction module, or ``None``.
            num_classes: output width of the classification head.
            num_queries: number of object queries per group, i.e. the maximal
                number of detections returned for one image.
            aux_loss: if True, ``forward`` also returns per-decoder-layer
                predictions under ``"aux_outputs"``.
            group_detr: number of query groups used during training (embeddings
                are sized ``num_queries * group_detr``); inference uses one group.
            two_stage: if True, per-group encoder-output class / box heads are
                attached to the transformer and their predictions are returned.
            lite_refpoint_refine: if False, the decoder shares this model's
                ``bbox_embed``; if True, ``decoder.bbox_embed`` is set to None.
            bbox_reparam: selects the box parameterisation used when decoding
                decoder outputs (see ``_decode_boxes``).
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.segmentation_head = segmentation_head
        query_dim = 4
        self.refpoint_embed = nn.Embedding(num_queries * group_detr, query_dim)
        self.query_feat = nn.Embedding(num_queries * group_detr, hidden_dim)
        nn.init.constant_(self.refpoint_embed.weight.data, 0)

        self.backbone = backbone
        self.aux_loss = aux_loss
        self.group_detr = group_detr

        # iter update
        self.lite_refpoint_refine = lite_refpoint_refine
        if not self.lite_refpoint_refine:
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            self.transformer.decoder.bbox_embed = None

        self.bbox_reparam = bbox_reparam

        # init prior_prob setting for focal loss
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value

        # zero-init the last box layer so initial predictions equal the references
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

        # two_stage
        self.two_stage = two_stage
        if self.two_stage:
            self.transformer.enc_out_bbox_embed = nn.ModuleList(
                [copy.deepcopy(self.bbox_embed) for _ in range(group_detr)]
            )
            self.transformer.enc_out_class_embed = nn.ModuleList(
                [copy.deepcopy(self.class_embed) for _ in range(group_detr)]
            )

        self._export = False

    def reinitialize_detection_head(self, num_classes):
        """Resize the classification head(s) to ``num_classes`` outputs in place.

        The existing rows are tiled ``ceil(num_classes / current_classes)``
        times and truncated to ``num_classes``; the same is applied to every
        encoder-output class head when ``two_stage`` is enabled.
        """
        base = self.class_embed.weight.shape[0]
        num_repeats = int(math.ceil(num_classes / base))
        self._resize_class_head(self.class_embed, num_classes, num_repeats)
        if self.two_stage:
            for enc_out_class_embed in self.transformer.enc_out_class_embed:
                self._resize_class_head(enc_out_class_embed, num_classes, num_repeats)

    @staticmethod
    def _resize_class_head(head, num_classes, num_repeats):
        """Tile-and-truncate the rows of one ``nn.Linear`` head to ``num_classes``."""
        head.weight.data = head.weight.data.repeat(num_repeats, 1)[:num_classes]
        head.bias.data = head.bias.data.repeat(num_repeats)[:num_classes]
        # keep the module's metadata consistent with the new parameter shapes
        head.out_features = num_classes

    def export(self):
        """Switch this model and every export-capable submodule to export mode.

        ``forward`` is replaced by ``forward_export`` on this instance (the
        original bound method is kept in ``_forward_origin``), and ``export()``
        is called on each submodule exposing a callable ``export`` and a falsy
        ``_export`` flag.
        """
        self._export = True
        self._forward_origin = self.forward
        self.forward = self.forward_export
        for module in self.modules():
            if (
                hasattr(module, "export")
                and callable(module.export)
                and hasattr(module, "_export")
                and not module._export
            ):
                module.export()

    def _decode_boxes(self, hs, ref_unsigmoid):
        """Decode decoder hidden states ``hs`` into box predictions.

        With ``bbox_reparam`` the box MLP predicts deltas: the centre offset is
        scaled by ``ref_unsigmoid[..., 2:]`` and added to ``ref_unsigmoid[..., :2]``,
        and the size is ``exp(delta) * ref_unsigmoid[..., 2:]``. Otherwise the MLP
        output is added to ``ref_unsigmoid`` and passed through a sigmoid.
        """
        if self.bbox_reparam:
            outputs_coord_delta = self.bbox_embed(hs)
            outputs_coord_cxcy = (
                outputs_coord_delta[..., :2] * ref_unsigmoid[..., 2:]
                + ref_unsigmoid[..., :2]
            )
            outputs_coord_wh = outputs_coord_delta[..., 2:].exp() * ref_unsigmoid[..., 2:]
            return torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1)
        return (self.bbox_embed(hs) + ref_unsigmoid).sigmoid()

    def forward(self, samples: NestedTensor, targets=None):
        """Run the detector on a batch of images.

        ``samples`` is a NestedTensor, or a list / tensor of images that is
        converted with ``nested_tensor_from_tensor_list``. It consists of:
            - samples.tensors: batched images, of shape [batch_size x 3 x H x W]
            - samples.mask: a binary mask of shape [batch_size x H x W],
              containing 1 on padded pixels

        Returns a dict with the following elements:
            - "pred_logits": classification logits of the last decoder layer,
              shape [batch_size x num_queries x num_classes].
            - "pred_boxes": box predictions of the last decoder layer
              (see ``_decode_boxes`` for the parameterisation).
            - "pred_masks": only when a segmentation head is configured.
            - "aux_outputs": only when ``aux_loss`` is enabled; a list of
              dictionaries with the keys above for each earlier decoder layer.
            - "enc_outputs": only with ``two_stage``; encoder-output predictions.
        If the transformer returns no decoder states, the encoder-output
        predictions (two-stage only) are returned at the top level instead.
        ``targets`` is accepted but not used here.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, poss = self.backbone(samples)

        srcs = []
        masks = []
        for feat in features:
            src, mask = feat.decompose()
            srcs.append(src)
            masks.append(mask)
            assert mask is not None

        if self.training:
            refpoint_embed_weight = self.refpoint_embed.weight
            query_feat_weight = self.query_feat.weight
        else:
            # only use one group in inference
            refpoint_embed_weight = self.refpoint_embed.weight[: self.num_queries]
            query_feat_weight = self.query_feat.weight[: self.num_queries]

        hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer(
            srcs, masks, poss, refpoint_embed_weight, query_feat_weight
        )

        # without decoder outputs, only the two-stage encoder heads can produce
        # predictions; fail explicitly instead of with an unbound `out`
        if hs is None:
            assert self.two_stage, "if not using decoder, two_stage must be True"

        out = {}
        if hs is not None:
            outputs_coord = self._decode_boxes(hs, ref_unsigmoid)
            outputs_class = self.class_embed(hs)
            outputs_masks = None
            if self.segmentation_head is not None:
                outputs_masks = self.segmentation_head(
                    features[0].tensors, hs, samples.tensors.shape[-2:]
                )

            out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
            if self.segmentation_head is not None:
                out["pred_masks"] = outputs_masks[-1]
            if self.aux_loss:
                out["aux_outputs"] = self._set_aux_loss(
                    outputs_class, outputs_coord, outputs_masks
                )

        if self.two_stage:
            group_detr = self.group_detr if self.training else 1
            hs_enc_list = hs_enc.chunk(group_detr, dim=1)
            cls_enc = torch.cat(
                [
                    self.transformer.enc_out_class_embed[g_idx](hs_enc_list[g_idx])
                    for g_idx in range(group_detr)
                ],
                dim=1,
            )

            enc_out = {"pred_logits": cls_enc, "pred_boxes": ref_enc}
            if self.segmentation_head is not None:
                masks_enc = self.segmentation_head(
                    features[0].tensors,
                    [hs_enc],
                    samples.tensors.shape[-2:],
                    skip_blocks=True,
                )
                enc_out["pred_masks"] = torch.cat(masks_enc, dim=1)

            if hs is not None:
                out["enc_outputs"] = enc_out
            else:
                out = enc_out

        return out

    def forward_export(self, tensors):
        """Export-mode forward pass on a plain image tensor.

        Uses a single query group and returns a tuple
        ``(boxes, logits)`` or ``(boxes, logits, masks)`` when a segmentation
        head is configured. Without decoder outputs, the first encoder-output
        class head is used, which requires ``two_stage``.
        """
        srcs, _, poss = self.backbone(tensors)
        # only use one group in inference
        refpoint_embed_weight = self.refpoint_embed.weight[: self.num_queries]
        query_feat_weight = self.query_feat.weight[: self.num_queries]

        hs, ref_unsigmoid, hs_enc, ref_enc = self.transformer(
            srcs, None, poss, refpoint_embed_weight, query_feat_weight
        )

        outputs_masks = None
        if hs is not None:
            outputs_coord = self._decode_boxes(hs, ref_unsigmoid)
            outputs_class = self.class_embed(hs)
            if self.segmentation_head is not None:
                outputs_masks = self.segmentation_head(
                    srcs[0],
                    [hs],
                    tensors.shape[-2:],
                )[0]
        else:
            assert self.two_stage, "if not using decoder, two_stage must be True"
            outputs_class = self.transformer.enc_out_class_embed[0](hs_enc)
            outputs_coord = ref_enc
            if self.segmentation_head is not None:
                outputs_masks = self.segmentation_head(
                    srcs[0],
                    [hs_enc],
                    tensors.shape[-2:],
                    skip_blocks=True,
                )[0]

        if outputs_masks is not None:
            return outputs_coord, outputs_class, outputs_masks
        return outputs_coord, outputs_class

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_masks):
        """Build per-layer prediction dicts for every decoder layer but the last."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if outputs_masks is not None:
            return [
                {"pred_logits": a, "pred_boxes": b, "pred_masks": c}
                for a, b, c in zip(
                    outputs_class[:-1], outputs_coord[:-1], outputs_masks[:-1]
                )
            ]
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]

    def update_drop_path(self, drop_path_rate, vit_encoder_num_layers):
        """Set linearly increasing drop-path rates (0 .. ``drop_path_rate``) on encoder blocks.

        Only blocks whose ``drop_path`` exposes a ``drop_prob`` attribute are updated.
        """
        dp_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, vit_encoder_num_layers)
        ]
        for i, rate in enumerate(dp_rates):
            encoder = self.backbone[0].encoder
            if hasattr(encoder, "blocks"):  # Not aimv2
                blocks = encoder.blocks
            else:  # aimv2 keeps its blocks under `trunk`
                blocks = encoder.trunk.blocks
            drop_path = blocks[i].drop_path
            if hasattr(drop_path, "drop_prob"):
                drop_path.drop_prob = rate

    def update_dropout(self, drop_rate):
        """Set the probability of every ``nn.Dropout`` inside the transformer."""
        for module in self.transformer.modules():
            if isinstance(module, nn.Dropout):
                module.p = drop_rate
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
class MLP(nn.Module):
    """Plain feed-forward network (FFN) of ``num_layers`` linear layers.

    Every layer except the last is followed by a ReLU.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        in_sizes = [input_dim] + hidden_sizes
        out_sizes = hidden_sizes + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, out_size) for in_size, out_size in zip(in_sizes, out_sizes)]
        )

    def forward(self, x):
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < self.num_layers - 1:
                x = F.relu(x)
        return x
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a sequence of 3-D image tensors to a common size and wrap them.

    The padded batch has shape ``[len(tensor_list)] + max shape per axis``; the
    returned mask is True on padded pixels and False where a source image was
    copied. Inputs whose first element is not 3-D raise ``ValueError``. While
    tracing, the ONNX-friendly implementation is used instead.
    """
    # TODO make this more general
    if tensor_list[0].ndim != 3:
        raise ValueError("not supported")
    if torchvision._is_tracing():
        # the in-place slice copies below do not export well to ONNX
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # TODO make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    batch_size, _, height, width = batch_shape
    reference = tensor_list[0]
    tensor = torch.zeros(batch_shape, dtype=reference.dtype, device=reference.device)
    mask = torch.ones(
        (batch_size, height, width), dtype=torch.bool, device=reference.device
    )
    for img, padded, img_mask in zip(tensor_list, tensor, mask):
        padded[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        img_mask[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
|
481
|
+
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
|
482
|
+
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """ONNX-traceable counterpart of ``nested_tensor_from_tensor_list``.

    It avoids the in-place slice copy / slice assignment of the eager path,
    which ONNX does not support. Instead:

    - each input is padded at the end of its axes with ``F.pad``;
    - the mask is built by padding an all-zero int map with ones;
    - the results are stacked.

    NOTE(review): ``img.shape[axis]`` values are passed to ``torch.stack``,
    which needs tensors. This presumably only works while tracing, where
    shape entries are traced tensors; confirm before calling it eagerly.

    Args:
        tensor_list: list of tensors indexed as 3-D below (padding uses axes
            0-2 and the mask uses axes 1-2).

    Returns:
        ``NestedTensor`` whose mask is ``True`` over padded positions.
    """
    ndim = tensor_list[0].dim()
    # Per-axis maximum. The stack/max runs in float32, then the result is cast
    # back to int64; keep this op sequence so the exported graph is unchanged.
    max_size = tuple(
        torch.max(
            torch.stack([img.shape[axis] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        for axis in range(ndim)
    )

    images = []
    masks = []
    for img in tensor_list:
        deltas = [target - actual for target, actual in zip(max_size, tuple(img.shape))]
        # F.pad takes pairs starting from the last axis: (W, H, C).
        images.append(
            torch.nn.functional.pad(img, (0, deltas[2], 0, deltas[1], 0, deltas[0]))
        )

        pad_flags = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        pad_flags = torch.nn.functional.pad(
            pad_flags, (0, deltas[2], 0, deltas[1]), "constant", 1
        )
        masks.append(pad_flags.to(torch.bool))

    return NestedTensor(torch.stack(images), mask=torch.stack(masks))
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def _max_by_axis(the_list):
|
|
518
|
+
# type: (List[List[int]]) -> List[int]
|
|
519
|
+
maxes = the_list[0]
|
|
520
|
+
for sublist in the_list[1:]:
|
|
521
|
+
for index, item in enumerate(sublist):
|
|
522
|
+
maxes[index] = max(maxes[index], item)
|
|
523
|
+
return maxes
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def build_model(config: ModelConfig) -> LWDETR:
    """Build an LW-DETR model, or just its backbone/encoder, from ``config``.

    ``config.dict()`` is expanded into ``populate_args``, so every config
    field becomes an attribute of the resulting namespace.

    Class counting: ``num_classes`` follows the DETR convention of
    ``max_obj_id + 1``, where ``max_obj_id`` is the largest class id in the
    dataset. For example, COCO (max id 90) uses 91, and a dataset with a single
    class id 1 uses 2. See
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223

    One more slot is added on top of ``args.num_classes`` before it is passed
    to ``LWDETR``. Presumably this is a no-object/background slot; confirm
    against ``LWDETR``.

    Returns:
        ``LWDETR`` in the normal case. Two flags short-circuit after the
        backbone is built:

        - ``args.encoder_only``: returns ``(backbone[0].encoder, None, None)``.
        - ``args.backbone_only``: returns ``(backbone, None, None)``.

        NOTE(review): the ``-> LWDETR`` annotation does not cover these tuple
        returns.
    """
    args = populate_args(**config.dict())
    num_classes = args.num_classes + 1

    # Backbone target shape: an explicit ``shape`` attribute wins, then a
    # square ``resolution``, then (640, 640).
    if hasattr(args, "shape"):
        target_shape = args.shape
    elif hasattr(args, "resolution"):
        target_shape = (args.resolution, args.resolution)
    else:
        target_shape = (640, 640)

    backbone = build_backbone(
        encoder=args.encoder,
        vit_encoder_num_layers=args.vit_encoder_num_layers,
        pretrained_encoder=args.pretrained_encoder,
        window_block_indexes=args.window_block_indexes,
        drop_path=args.drop_path,
        out_channels=args.hidden_dim,
        out_feature_indexes=args.out_feature_indexes,
        projector_scale=args.projector_scale,
        use_cls_token=args.use_cls_token,
        hidden_dim=args.hidden_dim,
        position_embedding=args.position_embedding,
        freeze_encoder=args.freeze_encoder,
        layer_norm=args.layer_norm,
        target_shape=target_shape,
        rms_norm=args.rms_norm,
        backbone_lora=args.backbone_lora,
        force_no_pretrain=args.force_no_pretrain,
        gradient_checkpointing=args.gradient_checkpointing,
        # DINOv2 weights are requested only when no full checkpoint is given.
        load_dinov2_weights=args.pretrain_weights is None,
        patch_size=config.patch_size,
        num_windows=config.num_windows,
        positional_encoding_size=config.positional_encoding_size,
    )

    if args.encoder_only:
        return backbone[0].encoder, None, None
    if args.backbone_only:
        return backbone, None, None

    # One feature level per projector scale entry.
    args.num_feature_levels = len(args.projector_scale)
    transformer = build_transformer(args)

    segmentation_head = None
    if args.segmentation_head:
        segmentation_head = SegmentationHead(
            args.hidden_dim,
            args.dec_layers,
            downsample_ratio=args.mask_downsample_ratio,
        )

    return LWDETR(
        backbone,
        transformer,
        segmentation_head,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
        group_detr=args.group_detr,
        two_stage=args.two_stage,
        lite_refpoint_refine=args.lite_refpoint_refine,
        bbox_reparam=args.bbox_reparam,
    )
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def populate_args(
    # Basic training parameters
    num_classes=2,
    grad_accum_steps=1,
    amp=False,
    lr=1e-4,
    lr_encoder=1.5e-4,
    batch_size=2,
    weight_decay=1e-4,
    epochs=12,
    lr_drop=11,
    clip_max_norm=0.1,
    lr_vit_layer_decay=0.8,
    lr_component_decay=1.0,
    do_benchmark=False,
    # Drop parameters
    dropout=0,
    drop_path=0,
    drop_mode="standard",
    drop_schedule="constant",
    cutoff_epoch=0,
    # Model parameters
    pretrained_encoder=None,
    pretrain_weights=None,
    pretrain_exclude_keys=None,
    pretrain_keys_modify_to_load=None,
    pretrained_distiller=None,
    # Backbone parameters
    encoder="vit_tiny",
    vit_encoder_num_layers=12,
    window_block_indexes=None,
    position_embedding="sine",
    out_feature_indexes=None,  # None -> fresh [-1] per call (see body)
    freeze_encoder=False,
    layer_norm=False,
    rms_norm=False,
    backbone_lora=False,
    force_no_pretrain=False,
    # Transformer parameters
    dec_layers=3,
    dim_feedforward=2048,
    hidden_dim=256,
    sa_nheads=8,
    ca_nheads=8,
    num_queries=300,
    group_detr=13,
    two_stage=False,
    projector_scale="P4",  # a bare string is wrapped into a one-element list
    lite_refpoint_refine=False,
    num_select=100,
    dec_n_points=4,
    decoder_norm="LN",
    bbox_reparam=False,
    freeze_batch_norm=False,
    # Matcher parameters
    set_cost_class=2,
    set_cost_bbox=5,
    set_cost_giou=2,
    # Loss coefficients
    cls_loss_coef=2,
    bbox_loss_coef=5,
    giou_loss_coef=2,
    focal_alpha=0.25,
    aux_loss=True,
    sum_group_losses=False,
    use_varifocal_loss=False,
    use_position_supervised_loss=False,
    ia_bce_loss=False,
    # Dataset parameters
    dataset_file="coco",
    coco_path=None,
    dataset_dir=None,
    square_resize_div_64=False,
    # Output parameters
    output_dir="output",
    dont_save_weights=False,
    checkpoint_interval=10,
    seed=42,
    resume="",
    start_epoch=0,
    eval=False,
    use_ema=False,
    ema_decay=0.9997,
    ema_tau=0,
    num_workers=2,
    # Distributed training parameters
    device="cuda",
    world_size=1,
    dist_url="env://",
    sync_bn=True,
    # FP16
    fp16_eval=False,
    # Custom args
    encoder_only=False,
    backbone_only=False,
    resolution=640,
    use_cls_token=False,
    multi_scale=False,
    expanded_scales=False,
    warmup_epochs=1,
    lr_scheduler="step",
    lr_min_factor=0.0,
    # Early stopping parameters
    early_stopping=True,
    early_stopping_patience=10,
    early_stopping_min_delta=0.001,
    early_stopping_use_ema=False,
    gradient_checkpointing=False,
    # Additional
    subcommand=None,
    **extra_kwargs,  # To handle any unexpected arguments
):
    """Collect model and training hyper-parameters into an ``argparse.Namespace``.

    Every named parameter is copied onto the namespace under its own name,
    except ``subcommand``, which is accepted and discarded. Any extra keyword
    arguments are attached unchanged, so callers can pass a whole config dict.

    Normalizations:
        out_feature_indexes: ``None`` (the default) becomes a fresh ``[-1]``
            on every call. The old ``[-1]`` default was one list object
            shared by every call and every returned namespace, so a mutation
            in one place leaked into later calls.
        projector_scale: a bare string such as ``"P4"`` is wrapped into
            ``["P4"]``. ``build_model`` derives the number of feature levels
            from ``len(args.projector_scale)``, which would otherwise count
            the string's characters (2) instead of scales (1).

    Returns:
        ``argparse.Namespace`` with one attribute per parameter (minus
        ``subcommand``) plus every key of ``extra_kwargs``.
    """
    if out_feature_indexes is None:
        out_feature_indexes = [-1]
    if isinstance(projector_scale, str):
        projector_scale = [projector_scale]

    args = argparse.Namespace(
        num_classes=num_classes,
        grad_accum_steps=grad_accum_steps,
        amp=amp,
        lr=lr,
        lr_encoder=lr_encoder,
        batch_size=batch_size,
        weight_decay=weight_decay,
        epochs=epochs,
        lr_drop=lr_drop,
        clip_max_norm=clip_max_norm,
        lr_vit_layer_decay=lr_vit_layer_decay,
        lr_component_decay=lr_component_decay,
        do_benchmark=do_benchmark,
        dropout=dropout,
        drop_path=drop_path,
        drop_mode=drop_mode,
        drop_schedule=drop_schedule,
        cutoff_epoch=cutoff_epoch,
        pretrained_encoder=pretrained_encoder,
        pretrain_weights=pretrain_weights,
        pretrain_exclude_keys=pretrain_exclude_keys,
        pretrain_keys_modify_to_load=pretrain_keys_modify_to_load,
        pretrained_distiller=pretrained_distiller,
        encoder=encoder,
        vit_encoder_num_layers=vit_encoder_num_layers,
        window_block_indexes=window_block_indexes,
        position_embedding=position_embedding,
        out_feature_indexes=out_feature_indexes,
        freeze_encoder=freeze_encoder,
        layer_norm=layer_norm,
        rms_norm=rms_norm,
        backbone_lora=backbone_lora,
        force_no_pretrain=force_no_pretrain,
        dec_layers=dec_layers,
        dim_feedforward=dim_feedforward,
        hidden_dim=hidden_dim,
        sa_nheads=sa_nheads,
        ca_nheads=ca_nheads,
        num_queries=num_queries,
        group_detr=group_detr,
        two_stage=two_stage,
        projector_scale=projector_scale,
        lite_refpoint_refine=lite_refpoint_refine,
        num_select=num_select,
        dec_n_points=dec_n_points,
        decoder_norm=decoder_norm,
        bbox_reparam=bbox_reparam,
        freeze_batch_norm=freeze_batch_norm,
        set_cost_class=set_cost_class,
        set_cost_bbox=set_cost_bbox,
        set_cost_giou=set_cost_giou,
        cls_loss_coef=cls_loss_coef,
        bbox_loss_coef=bbox_loss_coef,
        giou_loss_coef=giou_loss_coef,
        focal_alpha=focal_alpha,
        aux_loss=aux_loss,
        sum_group_losses=sum_group_losses,
        use_varifocal_loss=use_varifocal_loss,
        use_position_supervised_loss=use_position_supervised_loss,
        ia_bce_loss=ia_bce_loss,
        dataset_file=dataset_file,
        coco_path=coco_path,
        dataset_dir=dataset_dir,
        square_resize_div_64=square_resize_div_64,
        output_dir=output_dir,
        dont_save_weights=dont_save_weights,
        checkpoint_interval=checkpoint_interval,
        seed=seed,
        resume=resume,
        start_epoch=start_epoch,
        eval=eval,
        use_ema=use_ema,
        ema_decay=ema_decay,
        ema_tau=ema_tau,
        num_workers=num_workers,
        device=device,
        world_size=world_size,
        dist_url=dist_url,
        sync_bn=sync_bn,
        fp16_eval=fp16_eval,
        encoder_only=encoder_only,
        backbone_only=backbone_only,
        resolution=resolution,
        use_cls_token=use_cls_token,
        multi_scale=multi_scale,
        expanded_scales=expanded_scales,
        warmup_epochs=warmup_epochs,
        lr_scheduler=lr_scheduler,
        lr_min_factor=lr_min_factor,
        early_stopping=early_stopping,
        early_stopping_patience=early_stopping_patience,
        early_stopping_min_delta=early_stopping_min_delta,
        early_stopping_use_ema=early_stopping_use_ema,
        gradient_checkpointing=gradient_checkpointing,
        **extra_kwargs,
    )
    return args
|