python-wml 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic. Click here for more details.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import copy
|
|
3
|
+
import re
|
|
4
|
+
import torch
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
def convert_basic_c2_names(original_keys):
    """
    Apply some basic name conversion to names in C2 weights.
    It only deals with typical backbone models.

    Args:
        original_keys (list[str]): weight names from a Caffe2 checkpoint.

    Returns:
        list[str]: The same number of strings matching those in original_keys.
    """
    # A couple of hard-coded special cases come first.
    special_cases = {"pred_b": "linear_b", "pred_w": "linear_w"}
    keys = [special_cases.get(k, k) for k in original_keys]
    # C2 separates name components with "_" where torch modules use ".".
    keys = [k.replace("_", ".") for k in keys]

    # Ordered regex rewrites. Order matters: e.g. "bn.b" first becomes
    # "bn.bias" via the ".b$" rule and only then maps to "norm.bias".
    regex_rules = [
        (r"\.b$", ".bias"),
        (r"\.w$", ".weight"),
        # Uniform both bn and gn names to "norm"
        (r"bn\.s$", "norm.weight"),
        (r"bn\.bias$", "norm.bias"),
        (r"bn\.rm", "norm.running_mean"),
        (r"bn\.running.mean$", "norm.running_mean"),
        (r"bn\.riv$", "norm.running_var"),
        (r"bn\.running.var$", "norm.running_var"),
        (r"bn\.gamma$", "norm.weight"),
        (r"bn\.beta$", "norm.bias"),
        (r"gn\.s$", "norm.weight"),
        (r"gn\.bias$", "norm.bias"),
        # stem
        (r"^res\.conv1\.norm\.", "conv1.norm."),
        # to avoid mis-matching with "conv1" in other components (e.g. detection head)
        (r"^conv1\.", "stem.conv1."),
    ]
    for pattern, replacement in regex_rules:
        keys = [re.sub(pattern, replacement, k) for k in keys]

    # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)

    # Residual-block branch names -> shortcut / convN (literal substring replaces).
    branch_rules = [
        (".branch1.", ".shortcut."),
        (".branch2a.", ".conv1."),
        (".branch2b.", ".conv2."),
        (".branch2c.", ".conv3."),
    ]
    for old, new in branch_rules:
        keys = [k.replace(old, new) for k in keys]

    # DensePose substitutions
    keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in keys]
    densepose_rules = [
        ("AnnIndex.lowres", "ann_index_lowres"),
        ("Index.UV.lowres", "index_uv_lowres"),
        ("U.lowres", "u_lowres"),
        ("V.lowres", "v_lowres"),
    ]
    for old, new in densepose_rules:
        keys = [k.replace(old, new) for k in keys]
    return keys
|
|
61
|
+
|
|
62
|
+
def convert_c2_detectron_names(weights):
    """
    Map Caffe2 Detectron weight names to Detectron2 names.

    Args:
        weights (dict): name -> tensor

    Returns:
        dict: detectron2 names -> tensor
        dict: detectron2 names -> C2 names
    """
    logger = logging.getLogger(__name__)
    logger.info("Renaming Caffe2 weights ......")
    # Sort so that the renamed-key order is deterministic across runs.
    original_keys = sorted(weights.keys())
    layer_keys = copy.deepcopy(original_keys)

    # Backbone-level renames (underscores to dots, bn/gn -> norm, stem, blocks).
    layer_keys = convert_basic_c2_names(layer_keys)

    # --------------------------------------------------------------------------
    # RPN hidden representation conv
    # --------------------------------------------------------------------------
    # FPN case
    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
    # shared for all other levels, hence the appearance of "fpn2"
    layer_keys = [
        k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # RPN box transformation conv
    # --------------------------------------------------------------------------
    # FPN case (see note above about "fpn2")
    layer_keys = [
        k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
        for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]
    # Non-FPN case (must run after the fpn2 rules so the longer match wins).
    layer_keys = [
        k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]

    # --------------------------------------------------------------------------
    # Fast R-CNN box head
    # --------------------------------------------------------------------------
    layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
    layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
    layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
    layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
    # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
    layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # FPN lateral and output convolutions
    # --------------------------------------------------------------------------
    def fpn_map(name):
        """
        Look for keys with the following patterns:
        1) Starts with "fpn.inner."
           Example: "fpn.inner.res2.2.sum.lateral.weight"
           Meaning: These are lateral pathway convolutions
        2) Starts with "fpn.res"
           Example: "fpn.res2.2.sum.weight"
           Meaning: These are FPN output convolutions
        """
        splits = name.split(".")
        norm = ".norm" if "norm" in splits else ""
        if name.startswith("fpn.inner."):
            # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
            stage = int(splits[2][len("res") :])
            return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
        elif name.startswith("fpn.res"):
            # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
            stage = int(splits[1][len("res") :])
            return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
        # Anything else passes through unchanged.
        return name

    layer_keys = [fpn_map(k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # Mask R-CNN mask head
    # --------------------------------------------------------------------------
    # roi_heads.StandardROIHeads case
    layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
    layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
    layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
    # roi_heads.Res5ROIHeads case
    layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Keypoint R-CNN head
    # --------------------------------------------------------------------------
    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
    layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
    layer_keys = [
        k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
    ]
    layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Done with replacements
    # --------------------------------------------------------------------------
    # The renaming must be a bijection onto the original key set.
    assert len(set(layer_keys)) == len(layer_keys)
    assert len(original_keys) == len(layer_keys)

    new_weights = {}
    new_keys_to_original_keys = {}
    for orig, renamed in zip(original_keys, layer_keys):
        new_keys_to_original_keys[renamed] = orig
        if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
            # remove the meaningless prediction weight for background class
            # NOTE(review): 4 appears to be the number of box deltas per class
            # vs. 1 mask channel for the bg class — confirm against the C2 layout.
            new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
            new_weights[renamed] = weights[orig][new_start_idx:]
            logger.info(
                "Remove prediction weight for background class in {}. The shape changes from "
                "{} to {}.".format(
                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
                )
            )
        elif renamed.startswith("cls_score."):
            # move weights of bg class from original index 0 to last index
            logger.info(
                "Move classification weights for background class in {} from index 0 to "
                "index {}.".format(renamed, weights[orig].shape[0] - 1)
            )
            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
        else:
            new_weights[renamed] = weights[orig]

    return new_weights, new_keys_to_original_keys
|
|
201
|
+
|
|
202
|
+
def convert_ndarray_to_tensor(state_dict) -> None:
    """
    In-place convert all numpy arrays in the state_dict to torch tensor.
    Args:
        state_dict (dict): a state-dict to be loaded to the model.
            Will be modified.
    """
    # model could be an OrderedDict with _metadata attribute
    # (as returned by Pytorch's state_dict()). We should preserve these
    # properties.
    for key in list(state_dict.keys()):
        value = state_dict[key]
        if isinstance(value, torch.Tensor):
            # Already a tensor: nothing to do.
            continue
        if not isinstance(value, np.ndarray):
            raise ValueError(
                "Unsupported type found in checkpoint! {}: {}".format(key, type(value))
            )
        state_dict[key] = torch.from_numpy(value)
|
wml/wtorch/nets/fpn.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
1
|
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
2
|
+
import math
|
|
3
|
+
import fvcore.nn.weight_init as weight_init
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torch import nn
|
|
7
|
+
from .resnet.r50_config import _C as r50_config
|
|
8
|
+
|
|
9
|
+
from .shape_spec import ShapeSpec
|
|
10
|
+
from wml.wtorch.nn import Conv2d
|
|
11
|
+
from wml.wtorch.nn import get_norm
|
|
12
|
+
from .resnet.resnet import build_resnet_backbone
|
|
13
|
+
|
|
14
|
+
__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN",'build_resnet_fpn_backbonev2']
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class FPN(torch.nn.Module):
    """
    This module implements :paper:`FPN`.
    It creates pyramid features built on top of some input feature maps.
    """

    # Exposed as a typed attribute so TorchScript can treat it as a constant.
    _fuse_type: torch.jit.Final[str]

    def __init__(
        self, bottom_up, in_features, out_channels, norm="", top_block=None, fuse_type="sum"
    ):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            norm (str): the normalization to use.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra FPN levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
        """
        super(FPN, self).__init__()
        assert in_features, in_features

        self.bottom_up = bottom_up
        # Feature map strides and channels from the bottom up network (e.g. ResNet)
        input_shapes = bottom_up.output_shape()
        strides = [input_shapes[f].stride for f in in_features]
        in_channels_per_feature = [input_shapes[f].channels for f in in_features]

        # Each stride must be exactly 2x its predecessor (contiguous pyramid).
        _assert_strides_are_log2_contiguous(strides)
        lateral_convs = []
        output_convs = []

        # When a norm layer is attached, the conv bias is redundant.
        use_bias = norm == ""
        for idx, in_channels in enumerate(in_channels_per_feature):
            lateral_norm = get_norm(norm, out_channels)
            output_norm = get_norm(norm, out_channels)

            # 1x1 lateral conv projects the backbone map to out_channels.
            lateral_conv = Conv2d(
                in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm
            )
            # 3x3 output conv smooths the fused (lateral + top-down) map.
            output_conv = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
            )
            weight_init.c2_xavier_fill(lateral_conv)
            weight_init.c2_xavier_fill(output_conv)
            # Stage number is log2 of the stride (stride 4 -> stage 2, etc.).
            stage = int(math.log2(strides[idx]))
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.top_block = top_block
        self.in_features = tuple(in_features)
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        if self.top_block is not None:
            # NOTE: deliberately reuses `stage` leaked from the loop above,
            # i.e. the stage of the last (coarsest) backbone feature.
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        # Input image dims must be divisible by the coarsest stride.
        self._size_divisibility = strides[-1]
        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    @property
    def size_divisibility(self):
        # The stride of the coarsest FPN input level.
        return self._size_divisibility

    def forward(self, x):
        """
        Args:
            input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to
                feature map tensor for each feature level in high to low resolution order.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to FPN feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                paper convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.bottom_up(x)
        results = []
        # Start from the coarsest level: lateral conv only, no top-down input.
        prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
        results.append(self.output_convs[0](prev_features))

        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
            # Therefore we loop over all modules but skip the first one
            if idx > 0:
                features = self.in_features[-idx - 1]
                features = bottom_up_features[features]
                # Upsample the coarser map 2x and fuse with the lateral projection.
                top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
                lateral_features = lateral_conv(features)
                prev_features = lateral_features + top_down_features
                if self._fuse_type == "avg":
                    # In-place halving turns the sum into an element-wise mean.
                    prev_features /= 2
                # Insert at the front so results stay in high-to-low resolution order.
                results.insert(0, output_conv(prev_features))

        if self.top_block is not None:
            # The top block may consume either a raw backbone map or an FPN output.
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}

    def output_shape(self):
        # Channel count / stride for every produced "p<stage>" feature map.
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _assert_strides_are_log2_contiguous(strides):
|
|
165
|
+
"""
|
|
166
|
+
Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2".
|
|
167
|
+
"""
|
|
168
|
+
for i, stride in enumerate(strides[1:], 1):
|
|
169
|
+
assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format(
|
|
170
|
+
stride, strides[i - 1]
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class LastLevelMaxPool(nn.Module):
    """
    This module is used in the original FPN to generate a downsampled
    P6 feature from P5.
    """

    def __init__(self):
        super().__init__()
        # Contributes exactly one extra pyramid level, fed by the "p5" output.
        self.num_levels = 1
        self.in_feature = "p5"

    def forward(self, x):
        # kernel_size=1 with stride=2 is plain subsampling: keep every other pixel.
        pooled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [pooled]
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7 from
    C5 feature.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        # Two extra pyramid levels are derived from the configured input map.
        self.num_levels = 2
        self.in_feature = in_feature
        # Stride-2 3x3 convs, each halving the spatial resolution.
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for conv in (self.p6, self.p7):
            weight_init.c2_xavier_fill(conv)

    def forward(self, c5):
        # P7 is computed from a ReLU-activated P6, per the RetinaNet paper.
        p6 = self.p6(c5)
        return [p6, self.p7(F.relu(p6))]
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def build_resnet_fpn_backbone(cfg=r50_config, in_channels=3, in_features=None,
                              out_channels=256, norm='BN',
                              fuse_type='sum'):
    """
    Build a ResNet bottom-up backbone and wrap it in an FPN.

    Args:
        cfg: config node for the ResNet backbone (defaults to the bundled
            R50 config).
        in_channels (int): number of channels of the input image.
        in_features (list[str] or None): ResNet stages fed to the FPN; defaults
            to ["res2", "res3", "res4", "res5"].
        out_channels (int): number of channels of every FPN output map.
        norm (str): normalization used by the FPN convs.
        fuse_type (str): "sum" or "avg" fusion of lateral and top-down maps.

    Returns:
        backbone (FPN): backbone module, must be a subclass of :class:`Backbone`.
    """
    # None-sentinel instead of a mutable list default shared across calls.
    if in_features is None:
        in_features = ["res2", "res3", "res4", "res5"]
    bottom_up = build_resnet_backbone(cfg, in_channels=in_channels, out_features=in_features)
    # (Removed the original no-op self-assignments of in_features/out_channels.)
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=norm,
        top_block=None,
        fuse_type=fuse_type,
    )
    return backbone
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def build_resnet_fpn_backbonev2(bottom_up, in_features=None,
                                out_channels=256, norm='BN',
                                fuse_type='sum'):
    """
    Wrap an already-built bottom-up backbone in an FPN.

    Args:
        bottom_up: backbone module providing `output_shape()` and the feature
            maps named in `in_features`.
        in_features (list[str] or None): bottom-up feature names fed to the
            FPN; defaults to ["C2", "C3", "C4", "C5"].
        out_channels (int): number of channels of every FPN output map.
        norm (str): normalization used by the FPN convs.
        fuse_type (str): "sum" or "avg" fusion of lateral and top-down maps.

    Returns:
        backbone (FPN): backbone module, must be a subclass of :class:`Backbone`.
    """
    # None-sentinel instead of a mutable list default shared across calls.
    if in_features is None:
        in_features = ["C2", "C3", "C4", "C5"]
    # (Removed the original no-op self-assignments and the stale docstring
    # that referred to a nonexistent `cfg` argument.)
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=norm,
        top_block=None,
        fuse_type=fuse_type,
    )
    return backbone
|
|
255
|
+
|
|
256
|
+
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Args:
        cfg: a detectron2 CfgNode

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    bottom_up = build_resnet_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    # The P6/P7 block consumes the raw res5 feature map.
    p6p7_in_channels = bottom_up.output_shape()["res5"].channels
    return FPN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        norm=fpn_cfg.NORM,
        top_block=LastLevelP6P7(p6p7_in_channels, fpn_cfg.OUT_CHANNELS),
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
|
|
File without changes
|