python_wml-3.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python-wml might be problematic.
- python_wml-3.0.0.dist-info/LICENSE +23 -0
- python_wml-3.0.0.dist-info/METADATA +51 -0
- python_wml-3.0.0.dist-info/RECORD +164 -0
- python_wml-3.0.0.dist-info/WHEEL +5 -0
- python_wml-3.0.0.dist-info/top_level.txt +1 -0
- wml/__init__.py +0 -0
- wml/basic_data_def/__init__.py +2 -0
- wml/basic_data_def/detection_data_def.py +279 -0
- wml/basic_data_def/io_data_def.py +2 -0
- wml/basic_img_utils.py +816 -0
- wml/img_patch.py +92 -0
- wml/img_utils.py +571 -0
- wml/iotoolkit/__init__.py +17 -0
- wml/iotoolkit/aic_keypoint.py +115 -0
- wml/iotoolkit/baidu_mask_toolkit.py +244 -0
- wml/iotoolkit/base_dataset.py +210 -0
- wml/iotoolkit/bboxes_statistics.py +515 -0
- wml/iotoolkit/build.py +0 -0
- wml/iotoolkit/cityscapes_toolkit.py +183 -0
- wml/iotoolkit/classification_data_statistics.py +25 -0
- wml/iotoolkit/coco_data_fwd.py +225 -0
- wml/iotoolkit/coco_keypoints.py +118 -0
- wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
- wml/iotoolkit/coco_toolkit.py +397 -0
- wml/iotoolkit/coco_wholebody.py +269 -0
- wml/iotoolkit/common.py +108 -0
- wml/iotoolkit/crowd_pose.py +146 -0
- wml/iotoolkit/fast_labelme.py +110 -0
- wml/iotoolkit/image_folder.py +95 -0
- wml/iotoolkit/imgs_cache.py +58 -0
- wml/iotoolkit/imgs_reader_mt.py +73 -0
- wml/iotoolkit/labelme_base.py +102 -0
- wml/iotoolkit/labelme_json_to_img.py +49 -0
- wml/iotoolkit/labelme_toolkit.py +117 -0
- wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
- wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
- wml/iotoolkit/lspet.py +48 -0
- wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
- wml/iotoolkit/mat_data.py +90 -0
- wml/iotoolkit/mckeypoints_statistics.py +28 -0
- wml/iotoolkit/mot_datasets.py +62 -0
- wml/iotoolkit/mpii.py +108 -0
- wml/iotoolkit/npmckeypoints_dataset.py +164 -0
- wml/iotoolkit/o365_to_coco.py +136 -0
- wml/iotoolkit/object365_toolkit.py +156 -0
- wml/iotoolkit/object365v2_toolkit.py +71 -0
- wml/iotoolkit/pascal_voc_data.py +51 -0
- wml/iotoolkit/pascal_voc_toolkit.py +194 -0
- wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
- wml/iotoolkit/penn_action.py +57 -0
- wml/iotoolkit/rawframe_dataset.py +129 -0
- wml/iotoolkit/rewrite_pascal_voc.py +28 -0
- wml/iotoolkit/semantic_data.py +49 -0
- wml/iotoolkit/split_file_by_type.py +29 -0
- wml/iotoolkit/sports_mot_datasets.py +78 -0
- wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
- wml/iotoolkit/vis_torch_data.py +39 -0
- wml/iotoolkit/yolo_toolkit.py +38 -0
- wml/object_detection2/__init__.py +4 -0
- wml/object_detection2/basic_visualization.py +37 -0
- wml/object_detection2/bboxes.py +812 -0
- wml/object_detection2/data_process_toolkit.py +146 -0
- wml/object_detection2/keypoints.py +292 -0
- wml/object_detection2/mask.py +120 -0
- wml/object_detection2/metrics/__init__.py +3 -0
- wml/object_detection2/metrics/build.py +15 -0
- wml/object_detection2/metrics/classifier_toolkit.py +440 -0
- wml/object_detection2/metrics/common.py +71 -0
- wml/object_detection2/metrics/mckps_toolkit.py +338 -0
- wml/object_detection2/metrics/toolkit.py +1953 -0
- wml/object_detection2/npod_toolkit.py +361 -0
- wml/object_detection2/odtools.py +243 -0
- wml/object_detection2/standard_names.py +75 -0
- wml/object_detection2/visualization.py +956 -0
- wml/object_detection2/wmath.py +34 -0
- wml/semantic/__init__.py +0 -0
- wml/semantic/basic_toolkit.py +65 -0
- wml/semantic/mask_utils.py +156 -0
- wml/semantic/semantic_test.py +21 -0
- wml/semantic/structures.py +1 -0
- wml/semantic/toolkit.py +105 -0
- wml/semantic/visualization_utils.py +658 -0
- wml/threadtoolkit.py +50 -0
- wml/walgorithm.py +228 -0
- wml/wcollections.py +212 -0
- wml/wfilesystem.py +487 -0
- wml/wml_utils.py +657 -0
- wml/wstructures/__init__.py +4 -0
- wml/wstructures/common.py +9 -0
- wml/wstructures/keypoints_train_toolkit.py +149 -0
- wml/wstructures/kps_structures.py +579 -0
- wml/wstructures/mask_structures.py +1161 -0
- wml/wtorch/__init__.py +8 -0
- wml/wtorch/bboxes.py +104 -0
- wml/wtorch/classes_suppression.py +24 -0
- wml/wtorch/conv_module.py +181 -0
- wml/wtorch/conv_ws.py +144 -0
- wml/wtorch/data/__init__.py +16 -0
- wml/wtorch/data/_utils/__init__.py +45 -0
- wml/wtorch/data/_utils/collate.py +183 -0
- wml/wtorch/data/_utils/fetch.py +47 -0
- wml/wtorch/data/_utils/pin_memory.py +121 -0
- wml/wtorch/data/_utils/signal_handling.py +72 -0
- wml/wtorch/data/_utils/worker.py +227 -0
- wml/wtorch/data/base_data_loader_iter.py +93 -0
- wml/wtorch/data/dataloader.py +501 -0
- wml/wtorch/data/datapipes/__init__.py +1 -0
- wml/wtorch/data/datapipes/iter/__init__.py +12 -0
- wml/wtorch/data/datapipes/iter/batch.py +126 -0
- wml/wtorch/data/datapipes/iter/callable.py +92 -0
- wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
- wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
- wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
- wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
- wml/wtorch/data/datapipes/iter/sampler.py +94 -0
- wml/wtorch/data/datapipes/utils/__init__.py +0 -0
- wml/wtorch/data/datapipes/utils/common.py +65 -0
- wml/wtorch/data/dataset.py +354 -0
- wml/wtorch/data/datasets/__init__.py +4 -0
- wml/wtorch/data/datasets/common.py +53 -0
- wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
- wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
- wml/wtorch/data/distributed.py +135 -0
- wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
- wml/wtorch/data/sampler.py +267 -0
- wml/wtorch/data/single_process_data_loader_iter.py +24 -0
- wml/wtorch/data/test_data_loader.py +26 -0
- wml/wtorch/dataset_toolkit.py +67 -0
- wml/wtorch/depthwise_separable_conv_module.py +98 -0
- wml/wtorch/dist.py +591 -0
- wml/wtorch/dropblock/__init__.py +6 -0
- wml/wtorch/dropblock/dropblock.py +228 -0
- wml/wtorch/dropblock/dropout.py +40 -0
- wml/wtorch/dropblock/scheduler.py +48 -0
- wml/wtorch/ema.py +61 -0
- wml/wtorch/fc_module.py +73 -0
- wml/wtorch/functional.py +34 -0
- wml/wtorch/iter_dataset.py +26 -0
- wml/wtorch/loss.py +69 -0
- wml/wtorch/nets/__init__.py +0 -0
- wml/wtorch/nets/ckpt_toolkit.py +219 -0
- wml/wtorch/nets/fpn.py +276 -0
- wml/wtorch/nets/hrnet/__init__.py +0 -0
- wml/wtorch/nets/hrnet/config.py +2 -0
- wml/wtorch/nets/hrnet/hrnet.py +494 -0
- wml/wtorch/nets/misc.py +249 -0
- wml/wtorch/nets/resnet/__init__.py +0 -0
- wml/wtorch/nets/resnet/layers/__init__.py +17 -0
- wml/wtorch/nets/resnet/layers/aspp.py +144 -0
- wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
- wml/wtorch/nets/resnet/layers/blocks.py +111 -0
- wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
- wml/wtorch/nets/resnet/r50_config.py +38 -0
- wml/wtorch/nets/resnet/resnet.py +691 -0
- wml/wtorch/nets/shape_spec.py +20 -0
- wml/wtorch/nets/simple_fpn.py +101 -0
- wml/wtorch/nms.py +109 -0
- wml/wtorch/nn.py +896 -0
- wml/wtorch/ocr_block.py +193 -0
- wml/wtorch/summary.py +331 -0
- wml/wtorch/train_toolkit.py +603 -0
- wml/wtorch/transformer_blocks.py +266 -0
- wml/wtorch/utils.py +719 -0
- wml/wtorch/wlr_scheduler.py +100 -0
wml/wtorch/ocr_block.py
ADDED
@@ -0,0 +1,193 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def Norm2d(in_channels, **kwargs):
    """
    Custom norm factory to allow flexible switching of the normalization layer.
    """
    layer = nn.BatchNorm2d
    normalization_layer = layer(in_channels, **kwargs)
    return normalization_layer


def BNReLU(ch):
    return nn.Sequential(
        Norm2d(ch),
        nn.ReLU())


class SpatialGather_Module(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.

    Output:
        The correlation of every class map with every feature map,
        shape = [n, num_feats, num_classes, 1]
    """
    def __init__(self, cls_num=0, scale=1):
        super().__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        batch_size, c, _, _ = probs.size()

        # flatten each class map into a vector
        probs = probs.view(batch_size, c, -1)
        feats = feats.view(batch_size, feats.size(1), -1)

        feats = feats.permute(0, 2, 1)  # batch x hw x c
        probs = F.softmax(self.scale * probs, dim=2)  # batch x k x hw
        ocr_context = torch.matmul(probs, feats)
        ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3)
        return ocr_context


class ObjectAttentionBlock(nn.Module):
    '''
    The basic implementation of the object context block.
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : the scale at which to downsample the input feature
                maps (to save memory)
    Return:
        N X C X H X W
    '''
    def __init__(self, in_channels, key_channels, scale=1):
        super().__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        self.f_pixel = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
        )
        self.f_object = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
        )
        self.f_down = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
        )
        self.f_up = nn.Sequential(
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.in_channels),
        )

    def forward(self, x, proxy):
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)

        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels ** -0.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            context = F.interpolate(input=context, size=(h, w), mode='bilinear',
                                    align_corners=False)

        return context


class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation
    for each pixel.
    """
    def __init__(self, in_channels, key_channels, out_channels, scale=1,
                 dropout=0.1):
        super().__init__()
        self.object_context_block = ObjectAttentionBlock(in_channels,
                                                         key_channels,
                                                         scale)
        _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0,
                      bias=False),
            BNReLU(out_channels),
            nn.Dropout2d(dropout)
        )

    def forward(self, feats, proxy_feats):
        context = self.object_context_block(feats, proxy_feats)
        output = self.conv_bn_dropout(torch.cat([context, feats], 1))
        return output


class OCRBlock(nn.Module):
    """
    Some of the code in this class is borrowed from:
    https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/HRNet-OCR
    """
    def __init__(self, in_channels, num_classes, key_channels=256, mid_channel=512):
        super().__init__()

        ocr_mid_channels = mid_channel
        ocr_key_channels = key_channels

        self.conv3x3_ocr = nn.Sequential(
            nn.Conv2d(in_channels, ocr_mid_channels,
                      kernel_size=3, stride=1, padding=1),
            BNReLU(ocr_mid_channels),
        )
        self.ocr_gather_head = SpatialGather_Module(num_classes)
        self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
                                                 key_channels=ocr_key_channels,
                                                 out_channels=ocr_mid_channels,
                                                 scale=1,
                                                 dropout=0.05)
        self.cls_head = nn.Conv2d(
            ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0,
            bias=True)

        self.aux_head = nn.Sequential(
            nn.Conv2d(in_channels, in_channels,
                      kernel_size=1, stride=1, padding=0),
            BNReLU(in_channels),
            nn.Conv2d(in_channels, num_classes,
                      kernel_size=1, stride=1, padding=0, bias=True)
        )

    def forward(self, high_level_features):
        feats = self.conv3x3_ocr(high_level_features)
        aux_out = self.aux_head(high_level_features)
        context = self.ocr_gather_head(feats, aux_out)
        ocr_feats = self.ocr_distri_head(feats, context)
        cls_out = self.cls_head(ocr_feats)
        return cls_out, aux_out, ocr_feats
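For orientation, here is a minimal usage sketch of OCRBlock. It is not taken from the package; the channel count, class count, and input shape are illustrative assumptions:

import torch
from wml.wtorch.ocr_block import OCRBlock

# Hypothetical sizes: 720 backbone channels, 19 classes (Cityscapes-like).
block = OCRBlock(in_channels=720, num_classes=19)
feats = torch.randn(2, 720, 64, 128)  # high-level backbone features
cls_out, aux_out, ocr_feats = block(feats)
# cls_out:   [2, 19, 64, 128]  final per-class logits
# aux_out:   [2, 19, 64, 128]  auxiliary logits, used as the soft object regions
# ocr_feats: [2, 512, 64, 128] context-augmented features

The auxiliary head predicts coarse class maps directly from the backbone features; SpatialGather_Module turns them into one pooled representation per class, and SpatialOCR_Module attends from every pixel to those class representations before the final classifier.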
wml/wtorch/summary.py
ADDED
@@ -0,0 +1,331 @@
import torch
from collections.abc import Iterable
import wml.object_detection2.visualization as odv
import random
import numpy as np
import cv2
import wml.object_detection2.bboxes as odb
import wml.basic_img_utils as bwmli
#from torch.utils.tensorboard import SummaryWriter
#SummaryWriter.add_image()
#tb.add_images

def _draw_text_on_image(img, text, font_scale=1.2, color=(0., 255., 0.), pos=None):
    if isinstance(text, bytes):
        text = str(text, encoding="utf-8")
    if not isinstance(text, str):
        text = str(text)
    thickness = 2
    size = cv2.getTextSize(text, fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=font_scale, thickness=thickness)
    if pos is None:
        pos = (0, (img.shape[0] + size[0][1]) // 2)
    cv2.putText(img, text, pos, cv2.FONT_HERSHEY_COMPLEX, fontScale=font_scale, color=color, thickness=thickness)
    return img

def log_all_variable(tb, net: torch.nn.Module, global_step):
    try:
        for name, param in net.named_parameters():
            if 'bias' in name:
                name = "BIAS/" + name
            elif "." in name:
                name = name.replace(".", "/", 1)
            if param.numel() > 1:
                tb.add_histogram(name, param, global_step)
            else:
                tb.add_scalar(name, param, global_step)

        data = net.state_dict()
        for name in data:
            if "running" in name:
                param = data[name]
                if param.numel() > 1:
                    tb.add_histogram("BN/" + name, param, global_step)
                else:
                    tb.add_scalar("BN/" + name, param, global_step)
    except Exception as e:
        print("ERROR:", e)

def log_all_variable_min_max(tb, net: torch.nn.Module, global_step):
    try:
        for name, param in net.named_parameters():
            if 'bias' in name:
                name = "BIAS/" + name
            elif "." in name:
                name = name.replace(".", "/", 1)
            if param.numel() > 1:
                std, mean = torch.std_mean(param)
                tb.add_scalar(name + "_min", torch.min(param), global_step)
                tb.add_scalar(name + "_max", torch.max(param), global_step)
                tb.add_scalar(name + "_mean", mean, global_step)
                tb.add_scalar(name + "_std", std, global_step)
            else:
                tb.add_scalar(name, param, global_step)

        data = net.state_dict()
        for name in data:
            if "running" in name:
                param = data[name]
                if param.numel() > 1:
                    tb.add_histogram("BN/" + name, param, global_step)
                else:
                    tb.add_scalar("BN/" + name, param, global_step)
    except Exception as e:
        print("ERROR:", e)

def log_basic_info(tb, name, value: torch.Tensor, global_step):
    if value.numel() > 1:
        min_v = torch.min(value)
        max_v = torch.max(value)
        mean_v = torch.mean(value)
        std_v = torch.std(value)
        tb.add_scalar(name + "/min", min_v, global_step)
        tb.add_scalar(name + "/max", max_v, global_step)
        tb.add_scalar(name + "/mean", mean_v, global_step)
        tb.add_scalar(name + "/std", std_v, global_step)
    else:
        tb.add_scalar(name, value, global_step)

def add_image_with_label(tb, name, image, label, global_step):
    label = str(label)
    image = image.numpy()
    image = image.transpose(1, 2, 0)
    image = _draw_text_on_image(image, label)
    image = image.transpose(2, 0, 1)
    tb.add_image(name, image, global_step)

def add_images_with_label(tb, name, image, label, global_step, font_scale=1.2):
    if isinstance(image, torch.Tensor):
        image = image.numpy()
    image = image.transpose(0, 2, 3, 1)
    image = np.ascontiguousarray(image)
    if not isinstance(label, Iterable):
        label = str(label)
        image[0] = _draw_text_on_image(image[0], label, font_scale=font_scale)
    elif len(label) == 1:
        label = str(label[0])
        image[0] = _draw_text_on_image(image[0], label, font_scale=font_scale)
    elif len(label) == image.shape[0]:
        for i in range(len(label)):
            image[i] = _draw_text_on_image(image[i], str(label[i]), font_scale=font_scale)
    else:
        print(f"ERROR label {label}")
        return

    image = image.transpose(0, 3, 1, 2)
    tb.add_images(name, image, global_step)

def log_feature_map(tb, name, tensor, global_step, random_index=True):
    '''
    tensor: [B,C,H,W]
    '''
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().detach().numpy()

    if random_index:
        i = random.randint(0, tensor.shape[0] - 1)
    else:
        i = 0
    data = tensor[i]
    data = np.expand_dims(data, axis=1)
    min_v = np.min(data)
    max_v = np.max(data)
    data = (data - min_v) / (max_v - min_v + 1e-8)
    tb.add_images(name, data, global_step)

def try_log_rgb_feature_map(tb, name, tensor, global_step, random_index=True, min_upper_bounder=None, max_lower_bounder=None):
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().detach().numpy()

    if random_index:
        i = random.randint(0, tensor.shape[0] - 1)
    else:
        i = 0
    C = tensor.shape[1]
    data = tensor[i]
    min_v = np.min(data)
    if min_upper_bounder is not None:
        min_v = np.minimum(min_v, min_upper_bounder)
    max_v = np.max(data)
    if max_lower_bounder is not None:
        max_v = np.maximum(max_v, max_lower_bounder)
    data = (data - min_v) / (max_v - min_v + 1e-8)
    if C > 3:
        data = np.expand_dims(data, axis=1)
        tb.add_images(name, data, global_step)
    else:
        if C == 2:
            _, H, W = data.shape
            zeros = np.zeros([1, H, W], dtype=data.dtype)
            data = np.concatenate([data, zeros], axis=0)
        tb.add_image(name, data, global_step)

def log_heatmap_on_img(tb, name, img, heat_map, global_step, min_upper_bounder=None, max_lower_bounder=None):
    '''
    img: [H,W,C] (0~255)
    heat_map: [C,H,W]
    '''
    heat_map = heat_map.astype(np.float32)
    min_v = np.min(heat_map)
    if min_upper_bounder is not None:
        min_v = np.minimum(min_v, min_upper_bounder)
    max_v = np.max(heat_map)
    if max_lower_bounder is not None:
        max_v = np.maximum(max_v, max_lower_bounder)
    heat_map = (heat_map - min_v) / (max_v - min_v + 1e-8)
    img = odv.try_draw_rgb_heatmap_on_image(image=img,
                                            scores=heat_map)
    tb.add_image(name, img, global_step, dataformats="HWC")

def log_heatmap(tb, name, heat_map, global_step, min_upper_bounder=None, max_lower_bounder=None):
    '''
    heat_map: [C,H,W]
    '''
    heat_map = heat_map.astype(np.float32)
    min_v = np.min(heat_map)
    if min_upper_bounder is not None:
        min_v = np.minimum(min_v, min_upper_bounder)
    max_v = np.max(heat_map)
    if max_lower_bounder is not None:
        max_v = np.maximum(max_v, max_lower_bounder)
    heat_map = (heat_map - min_v) / (max_v - min_v + 1e-8)
    img = odv.try_draw_rgb_heatmap_on_image(image=np.zeros([heat_map.shape[1], heat_map.shape[2], 3], dtype=np.uint8),
                                            color_pos=(255, 0, 0),
                                            color_neg=(0, 0, 255),
                                            scores=heat_map, alpha=1.0)
    tb.add_image(name, img, global_step, dataformats="HWC")


def log_heatmap_on_imgv2(tb, name, img, heat_map, global_step, min_upper_bounder=None, max_lower_bounder=None):
    '''
    Uses a richer pseudo-color palette.
    img: [H,W,C] (0~255)
    heat_map: [C,H,W]
    '''
    heat_map = heat_map.astype(np.float32)
    min_v = np.min(heat_map)
    if min_upper_bounder is not None:
        min_v = np.minimum(min_v, min_upper_bounder)
    max_v = np.max(heat_map)
    if max_lower_bounder is not None:
        max_v = np.maximum(max_v, max_lower_bounder)
    heat_map = (heat_map - min_v) / (max_v - min_v + 1e-8)
    img = odv.try_draw_rgb_heatmap_on_imagev2(image=img,
                                              palette=[(0, (0, 0, 0)), (0.5, (0, 0, 0)), (1.0, (255, 0, 0))],
                                              scores=heat_map)
    tb.add_image(name, img, global_step, dataformats="HWC")

def log_heatmapv2(tb, name, heat_map, global_step, min_upper_bounder=None, max_lower_bounder=None):
    '''
    heat_map: [C,H,W]
    Uses a richer pseudo-color palette.
    '''
    heat_map = heat_map.astype(np.float32)
    heat_map = np.sum(heat_map, axis=0, keepdims=False)
    min_v = np.min(heat_map)
    if min_upper_bounder is not None:
        min_v = np.minimum(min_v, min_upper_bounder)
    max_v = np.max(heat_map)
    if max_lower_bounder is not None:
        max_v = np.maximum(max_v, max_lower_bounder)
    heat_map = (heat_map - min_v) / (max_v - min_v + 1e-8)

    palette = [(0, (0, 0, 255)), (0.5, (255, 255, 255)), (1.0, (255, 0, 0))]
    img = bwmli.pseudocolor_img(img=heat_map, palette=palette, auto_norm=False)
    img = img.astype(np.uint8)
    tb.add_image(name, img, global_step, dataformats="HWC")

def add_video_with_label(tb, name, video, label, global_step, fps=4, font_scale=1.2):
    '''
    Args:
        tb: summary writer
        name: tag name
        video: (N, T, C, H, W)
        label: one label per clip, or None
        global_step: global step
        fps: frames per second of the logged video
        font_scale: font scale used to draw the labels
    '''
    if isinstance(video, torch.Tensor):
        video = video.numpy()
    video = video.transpose(0, 1, 3, 4, 2)
    video = np.ascontiguousarray(video)
    if label is not None:
        for i in range(video.shape[0]):
            l = label[i]
            for j in range(video.shape[1]):
                _draw_text_on_image(video[i, j], l, font_scale=font_scale)
    # video: (N,T,H,W,C)
    video = video.transpose(0, 1, 4, 2, 3)
    tb.add_video(name, video, global_step, fps=fps)

def log_mask(tb, tag, images, masks, step, color_map, img_min=None, img_max=None, ignore_label=255, max_images=4, save_raw_imgs=False):
    images = images[:max_images]
    masks = masks[:max_images]
    if img_min is None or img_max is None:
        img_min = torch.min(images)
        img_max = torch.max(images)
    images = (images - img_min) * 255 / (img_max - img_min + 1e-8)
    images = torch.clip(images, 0, 255)
    images = images.to(torch.uint8)
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    masks = masks.cpu().numpy()
    res_imgs = []
    for img, msk in zip(images, masks):
        r_img = odv.draw_semantic_on_image(img, msk, color_map, ignored_label=ignore_label)
        res_imgs.append(r_img)
    res_images = np.stack(res_imgs, axis=0)
    tb.add_images(tag, res_images, step, dataformats='NHWC')
    if save_raw_imgs:
        tb.add_images(tag + "_raw", images, step, dataformats='NHWC')


def log_optimizer(tb, optimizer, step, name=""):
    for i, data in enumerate(optimizer.param_groups):
        bname = f"{name} optimizer/{i}_{len(data['params'])}"
        tb.add_scalar(bname + "_lr", data['lr'], step)
        tb.add_scalar(bname + "_wd", data['weight_decay'], step)


def log_imgs_with_bboxes(tb, name, imgs, targets, step, max_imgs=None):
    '''
    imgs: [B,C,H,W] (0-255)
    targets: [B,N,5] with rows of [label, x0, y0, x1, y1]
    '''
    if max_imgs is not None and max_imgs > 0:
        imgs = imgs[:max_imgs]
        targets = targets[:max_imgs]

    if torch.is_tensor(imgs):
        imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    if torch.is_tensor(targets):
        targets = targets.detach().cpu().numpy()

    bboxes = targets[..., 1:5]
    labels = targets[..., 0].astype(np.int32)
    bboxes = odb.npchangexyorder(bboxes)

    res_imgs = []
    for i in range(imgs.shape[0]):
        img = imgs[i]
        img = np.ascontiguousarray(np.transpose(img, [1, 2, 0]))
        cur_bboxes = bboxes[i]
        cur_labels = labels[i]
        bboxes_nr = np.count_nonzero(np.sum(cur_bboxes, axis=-1) > 0)
        cur_bboxes = cur_bboxes[:bboxes_nr]
        cur_labels = cur_labels[:bboxes_nr]
        img = odv.draw_bboxes(img, cur_labels, bboxes=cur_bboxes, is_relative_coordinate=False)
        res_imgs.append(img)

    res_imgs = np.array(res_imgs)
    res_imgs = np.ascontiguousarray(res_imgs)
    tb.add_images(name, res_imgs, step, dataformats='NHWC')


def log_semantic_seg(tb, name, img, seg, global_step, alpha=0.4, ignore_idx=255):
    img = odv.draw_seg_on_img(img, seg, alpha=alpha, ignore_idx=ignore_idx)
    tb.add_image(name, img, global_step, dataformats="HWC")
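For orientation, a minimal sketch of how these helpers might be driven from a training loop, assuming a torch.utils.tensorboard SummaryWriter is passed as the tb argument (the model, optimizer, and log directory below are illustrative, not from the package):

import torch
from torch.utils.tensorboard import SummaryWriter
from wml.wtorch import summary as wsummary

# Tiny hypothetical model and optimizer, just to have parameters to log.
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
opt = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
tb = SummaryWriter(log_dir="/tmp/wml_demo")

step = 0
wsummary.log_all_variable(tb, net, step)       # weight/bias histograms + BN running stats
wsummary.log_optimizer(tb, opt, step)          # per-param-group lr and weight decay
wsummary.log_basic_info(tb, "act", torch.randn(4, 8), step)  # min/max/mean/std scalars
tb.close()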