python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164) hide show
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,219 @@
1
+ import logging
2
+ import copy
3
+ import re
4
+ import torch
5
+ import numpy as np
6
+
7
def convert_basic_c2_names(original_keys):
    """
    Apply some basic name conversion to names in C2 weights.
    It only deals with typical backbone models.

    Args:
        original_keys (list[str]):
    Returns:
        list[str]: The same number of strings matching those in original_keys.
    """
    # A few fully hard-coded mappings, applied before everything else.
    special_names = {"pred_b": "linear_b", "pred_w": "linear_w"}
    keys = [special_names.get(name, name) for name in copy.deepcopy(original_keys)]
    # C2 separates name components with "_"; Detectron2 uses ".".
    keys = [name.replace("_", ".") for name in keys]

    # Regex renames applied in order: weight/bias suffixes, then uniforming
    # both bn and gn names to "norm", then the stem convolution. The stem
    # rename "^conv1." -> "stem.conv1." avoids mis-matching with "conv1" in
    # other components (e.g. a detection head).
    regex_rules = (
        ("\\.b$", ".bias"),
        ("\\.w$", ".weight"),
        ("bn\\.s$", "norm.weight"),
        ("bn\\.bias$", "norm.bias"),
        ("bn\\.rm", "norm.running_mean"),
        ("bn\\.running.mean$", "norm.running_mean"),
        ("bn\\.riv$", "norm.running_var"),
        ("bn\\.running.var$", "norm.running_var"),
        ("bn\\.gamma$", "norm.weight"),
        ("bn\\.beta$", "norm.bias"),
        ("gn\\.s$", "norm.weight"),
        ("gn\\.bias$", "norm.bias"),
        ("^res\\.conv1\\.norm\\.", "conv1.norm."),
        ("^conv1\\.", "stem.conv1."),
    )
    for pattern, replacement in regex_rules:
        keys = [re.sub(pattern, replacement, name) for name in keys]

    # Residual blocks: C2 "branch" names -> shortcut / conv1..conv3.
    branch_rules = (
        (".branch1.", ".shortcut."),
        (".branch2a.", ".conv1."),
        (".branch2b.", ".conv2."),
        (".branch2c.", ".conv3."),
    )
    for old, new in branch_rules:
        keys = [name.replace(old, new) for name in keys]

    # DensePose substitutions.
    keys = [re.sub("^body.conv.fcn", "body_conv_fcn", name) for name in keys]
    densepose_rules = (
        ("AnnIndex.lowres", "ann_index_lowres"),
        ("Index.UV.lowres", "index_uv_lowres"),
        ("U.lowres", "u_lowres"),
        ("V.lowres", "v_lowres"),
    )
    for old, new in densepose_rules:
        keys = [name.replace(old, new) for name in keys]
    return keys
61
+
62
def convert_c2_detectron_names(weights):
    """
    Map Caffe2 Detectron weight names to Detectron2 names.

    Args:
        weights (dict): name -> tensor

    Returns:
        dict: detectron2 names -> tensor
        dict: detectron2 names -> C2 names
    """
    logger = logging.getLogger(__name__)
    logger.info("Renaming Caffe2 weights ......")
    original_keys = sorted(weights.keys())
    layer_keys = copy.deepcopy(original_keys)

    # Backbone-level renames (bn/gn -> norm, branch -> conv, stem, etc.).
    layer_keys = convert_basic_c2_names(layer_keys)

    # --------------------------------------------------------------------------
    # RPN hidden representation conv
    # --------------------------------------------------------------------------
    # FPN case
    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
    # shared for all other levels, hence the appearance of "fpn2"
    layer_keys = [
        k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # RPN box transformation conv
    # --------------------------------------------------------------------------
    # FPN case (see note above about "fpn2")
    layer_keys = [
        k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
        for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]
    # Non-FPN case
    layer_keys = [
        k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
    ]
    layer_keys = [
        k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
        for k in layer_keys
    ]

    # --------------------------------------------------------------------------
    # Fast R-CNN box head
    # --------------------------------------------------------------------------
    layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
    layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
    layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
    layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
    # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
    layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # FPN lateral and output convolutions
    # --------------------------------------------------------------------------
    def fpn_map(name):
        """
        Look for keys with the following patterns:
        1) Starts with "fpn.inner."
           Example: "fpn.inner.res2.2.sum.lateral.weight"
           Meaning: These are lateral pathway convolutions
        2) Starts with "fpn.res"
           Example: "fpn.res2.2.sum.weight"
           Meaning: These are FPN output convolutions
        """
        splits = name.split(".")
        # Preserve an optional ".norm" component in the converted name.
        norm = ".norm" if "norm" in splits else ""
        if name.startswith("fpn.inner."):
            # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
            stage = int(splits[2][len("res") :])
            return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
        elif name.startswith("fpn.res"):
            # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
            stage = int(splits[1][len("res") :])
            return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
        return name

    layer_keys = [fpn_map(k) for k in layer_keys]

    # --------------------------------------------------------------------------
    # Mask R-CNN mask head
    # --------------------------------------------------------------------------
    # roi_heads.StandardROIHeads case
    layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
    layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
    layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
    # roi_heads.Res5ROIHeads case
    layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Keypoint R-CNN head
    # --------------------------------------------------------------------------
    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
    layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
    layer_keys = [
        k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
    ]
    layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]

    # --------------------------------------------------------------------------
    # Done with replacements
    # --------------------------------------------------------------------------
    # The rename must be a bijection: no two C2 names may collapse onto one
    # Detectron2 name, and every original key must still be accounted for.
    assert len(set(layer_keys)) == len(layer_keys)
    assert len(original_keys) == len(layer_keys)

    new_weights = {}
    new_keys_to_original_keys = {}
    for orig, renamed in zip(original_keys, layer_keys):
        new_keys_to_original_keys[renamed] = orig
        if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
            # remove the meaningless prediction weight for background class
            # (4 box-regression values for bbox_pred, 1 channel for the mask predictor)
            new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
            new_weights[renamed] = weights[orig][new_start_idx:]
            logger.info(
                "Remove prediction weight for background class in {}. The shape changes from "
                "{} to {}.".format(
                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
                )
            )
        elif renamed.startswith("cls_score."):
            # move weights of bg class from original index 0 to last index
            logger.info(
                "Move classification weights for background class in {} from index 0 to "
                "index {}.".format(renamed, weights[orig].shape[0] - 1)
            )
            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
        else:
            new_weights[renamed] = weights[orig]

    return new_weights, new_keys_to_original_keys
201
+
202
def convert_ndarray_to_tensor(state_dict) -> None:
    """
    In-place convert all numpy arrays in the state_dict to torch tensor.
    Args:
        state_dict (dict): a state-dict to be loaded to the model.
            Will be modified.
    """
    # The dict may be an OrderedDict carrying a _metadata attribute (as
    # returned by PyTorch's state_dict()); mutating values in place keeps
    # such properties intact. Snapshot the keys since values are reassigned.
    for key in list(state_dict.keys()):
        value = state_dict[key]
        if not isinstance(value, (np.ndarray, torch.Tensor)):
            raise ValueError(
                f"Unsupported type found in checkpoint! {key}: {type(value)}"
            )
        if isinstance(value, np.ndarray):
            state_dict[key] = torch.from_numpy(value)
wml/wtorch/nets/fpn.py ADDED
@@ -0,0 +1,276 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import math
3
+ import fvcore.nn.weight_init as weight_init
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from .resnet.r50_config import _C as r50_config
8
+
9
+ from .shape_spec import ShapeSpec
10
+ from wml.wtorch.nn import Conv2d
11
+ from wml.wtorch.nn import get_norm
12
+ from .resnet.resnet import build_resnet_backbone
13
+
14
+ __all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN",'build_resnet_fpn_backbonev2']
15
+
16
+
17
class FPN(torch.nn.Module):
    """
    This module implements :paper:`FPN`.
    It creates pyramid features built on top of some input feature maps.
    """

    # Declared Final so TorchScript treats the fuse mode as a compile-time constant.
    _fuse_type: torch.jit.Final[str]

    def __init__(
        self, bottom_up, in_features, out_channels, norm="", top_block=None, fuse_type="sum"
    ):
        """
        Args:
            bottom_up (Backbone): module representing the bottom up subnetwork.
                Must be a subclass of :class:`Backbone`. The multi-scale feature
                maps generated by the bottom up network, and listed in `in_features`,
                are used to generate FPN levels.
            in_features (list[str]): names of the input feature maps coming
                from the backbone to which FPN is attached. For example, if the
                backbone produces ["res2", "res3", "res4"], any *contiguous* sublist
                of these may be used; order must be from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            norm (str): the normalization to use.
            top_block (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list. The top_block
                further downsamples the feature map. It must have an attribute
                "num_levels", meaning the number of extra FPN levels added by
                this block, and "in_feature", which is a string representing
                its input feature (e.g., p5).
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
        """
        super(FPN, self).__init__()
        assert in_features, in_features

        self.bottom_up = bottom_up
        # Feature map strides and channels from the bottom up network (e.g. ResNet)
        input_shapes = bottom_up.output_shape()
        strides = [input_shapes[f].stride for f in in_features]
        in_channels_per_feature = [input_shapes[f].channels for f in in_features]

        _assert_strides_are_log2_contiguous(strides)
        lateral_convs = []
        output_convs = []

        # A norm layer carries its own affine bias, so the conv omits one.
        use_bias = norm == ""
        for idx, in_channels in enumerate(in_channels_per_feature):
            lateral_norm = get_norm(norm, out_channels)
            output_norm = get_norm(norm, out_channels)

            # 1x1 lateral conv projects the backbone feature to out_channels.
            lateral_conv = Conv2d(
                in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm
            )
            # 3x3 output conv smooths the fused (lateral + top-down) feature.
            output_conv = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=output_norm,
            )
            weight_init.c2_xavier_fill(lateral_conv)
            weight_init.c2_xavier_fill(output_conv)
            # Stage number from the stride: stride 4 -> stage 2, stride 8 -> 3, ...
            stage = int(math.log2(strides[idx]))
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)
        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.top_block = top_block
        self.in_features = tuple(in_features)
        # Return feature names are "p<stage>", like ["p2", "p3", ..., "p6"]
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # top block output feature maps.
        # NOTE: relies on `stage` leaking from the loop above, i.e. the stage of
        # the last (lowest resolution) input level.
        if self.top_block is not None:
            for s in range(stage, stage + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1)

        self._out_features = list(self._out_feature_strides.keys())
        self._out_feature_channels = {k: out_channels for k in self._out_features}
        # Inputs must be divisible by the largest stride for shapes to line up.
        self._size_divisibility = strides[-1]
        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    @property
    def size_divisibility(self):
        # Input height/width should be divisible by this value.
        return self._size_divisibility

    def forward(self, x):
        """
        Args:
            input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to
                feature map tensor for each feature level in high to low resolution order.

        Returns:
            dict[str->Tensor]:
                mapping from feature map name to FPN feature map tensor
                in high to low resolution order. Returned feature names follow the FPN
                paper convention: "p<stage>", where stage has stride = 2 ** stage e.g.,
                ["p2", "p3", ..., "p6"].
        """
        bottom_up_features = self.bottom_up(x)
        results = []
        # Start from the lowest resolution input (lateral_convs is top-down ordered).
        prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]])
        results.append(self.output_convs[0](prev_features))

        # Reverse feature maps into top-down order (from low to high resolution)
        for idx, (lateral_conv, output_conv) in enumerate(
            zip(self.lateral_convs, self.output_convs)
        ):
            # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336
            # Therefore we loop over all modules but skip the first one
            if idx > 0:
                features = self.in_features[-idx - 1]
                features = bottom_up_features[features]
                # Upsample the coarser level by 2x and fuse with the lateral feature.
                top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest")
                lateral_features = lateral_conv(features)
                prev_features = lateral_features + top_down_features
                if self._fuse_type == "avg":
                    prev_features /= 2
                # Insert at the front so results stay in high-to-low resolution order.
                results.insert(0, output_conv(prev_features))

        if self.top_block is not None:
            # Feed the top block from the backbone feature if available,
            # otherwise from the matching FPN output (e.g. "p5").
            if self.top_block.in_feature in bottom_up_features:
                top_block_in_feature = bottom_up_features[self.top_block.in_feature]
            else:
                top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
            results.extend(self.top_block(top_block_in_feature))
        assert len(self._out_features) == len(results)
        return {f: res for f, res in zip(self._out_features, results)}

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
162
+
163
+
164
+ def _assert_strides_are_log2_contiguous(strides):
165
+ """
166
+ Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2".
167
+ """
168
+ for i, stride in enumerate(strides[1:], 1):
169
+ assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format(
170
+ stride, strides[i - 1]
171
+ )
172
+
173
+
174
class LastLevelMaxPool(nn.Module):
    """
    This module is used in the original FPN to generate a downsampled
    P6 feature from P5.
    """

    def __init__(self):
        super().__init__()
        # One extra level is produced, fed from the "p5" FPN output.
        self.in_feature = "p5"
        self.num_levels = 1

    def forward(self, x):
        # kernel_size=1 with stride=2 simply subsamples every other position.
        pooled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [pooled]
187
+
188
+
189
class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7 from
    C5 feature.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        # Two extra levels, each halving resolution via a stride-2 3x3 conv.
        self.num_levels = 2
        self.in_feature = in_feature
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        weight_init.c2_xavier_fill(self.p6)
        weight_init.c2_xavier_fill(self.p7)

    def forward(self, c5):
        p6 = self.p6(c5)
        # P7 is computed from a ReLU-activated P6.
        return [p6, self.p7(F.relu(p6))]
208
+
209
+
210
def build_resnet_fpn_backbone(cfg=r50_config, in_channels=3, in_features=None,
                              out_channels=256, norm='BN',
                              fuse_type='sum'):
    """
    Build a ResNet bottom-up backbone wrapped in an FPN.

    Args:
        cfg: a detectron2-style CfgNode describing the ResNet (defaults to the
            bundled R50 config).
        in_channels (int): number of channels of the input image.
        in_features (list[str] or None): ResNet stages fed to the FPN;
            defaults to ["res2", "res3", "res4", "res5"].
        out_channels (int): number of channels of every FPN output map.
        norm (str): normalization used inside the FPN convs.
        fuse_type (str): "sum" or "avg"; how top-down and lateral maps fuse.

    Returns:
        backbone (FPN): backbone module, must be a subclass of :class:`Backbone`.
    """
    # Use a None sentinel instead of a mutable list default (shared across calls).
    if in_features is None:
        in_features = ["res2", "res3", "res4", "res5"]
    bottom_up = build_resnet_backbone(cfg, in_channels=in_channels, out_features=in_features)
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=norm,
        top_block=None,
        fuse_type=fuse_type,
    )
    return backbone
232
+
233
+
234
def build_resnet_fpn_backbonev2(bottom_up, in_features=None,
                                out_channels=256, norm='BN',
                                fuse_type='sum'):
    """
    Wrap an already-constructed bottom-up backbone in an FPN.

    Args:
        bottom_up: bottom-up backbone module; must expose `output_shape()`
            covering every name in `in_features`.
        in_features (list[str] or None): bottom-up feature names fed to the
            FPN; defaults to ["C2", "C3", "C4", "C5"].
        out_channels (int): number of channels of every FPN output map.
        norm (str): normalization used inside the FPN convs.
        fuse_type (str): "sum" or "avg"; how top-down and lateral maps fuse.

    Returns:
        backbone (FPN): backbone module, must be a subclass of :class:`Backbone`.
    """
    # Use a None sentinel instead of a mutable list default (shared across calls).
    if in_features is None:
        in_features = ["C2", "C3", "C4", "C5"]
    backbone = FPN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=norm,
        top_block=None,
        fuse_type=fuse_type,
    )
    return backbone
255
+
256
def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec):
    """
    Args:
        cfg: a detectron2 CfgNode

    Returns:
        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
    """
    fpn_cfg = cfg.MODEL.FPN
    bottom_up = build_resnet_backbone(cfg, input_shape)
    # The P6/P7 top block consumes the raw res5 feature of the bottom-up net.
    p6p7_in_channels = bottom_up.output_shape()["res5"].channels
    top_block = LastLevelP6P7(p6p7_in_channels, fpn_cfg.OUT_CHANNELS)
    return FPN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        norm=fpn_cfg.NORM,
        top_block=top_block,
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
File without changes
@@ -0,0 +1,2 @@
1
from yacs.config import CfgNode as CN

# Root configuration node for the HRNet configs in this package.
# `new_allowed=True` permits attaching keys that were not pre-declared,
# so downstream configs can extend it freely.
_C = CN(new_allowed=True)