python-wml 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of python-wml might be problematic. Click here for more details.

Files changed (164)
  1. python_wml-3.0.0.dist-info/LICENSE +23 -0
  2. python_wml-3.0.0.dist-info/METADATA +51 -0
  3. python_wml-3.0.0.dist-info/RECORD +164 -0
  4. python_wml-3.0.0.dist-info/WHEEL +5 -0
  5. python_wml-3.0.0.dist-info/top_level.txt +1 -0
  6. wml/__init__.py +0 -0
  7. wml/basic_data_def/__init__.py +2 -0
  8. wml/basic_data_def/detection_data_def.py +279 -0
  9. wml/basic_data_def/io_data_def.py +2 -0
  10. wml/basic_img_utils.py +816 -0
  11. wml/img_patch.py +92 -0
  12. wml/img_utils.py +571 -0
  13. wml/iotoolkit/__init__.py +17 -0
  14. wml/iotoolkit/aic_keypoint.py +115 -0
  15. wml/iotoolkit/baidu_mask_toolkit.py +244 -0
  16. wml/iotoolkit/base_dataset.py +210 -0
  17. wml/iotoolkit/bboxes_statistics.py +515 -0
  18. wml/iotoolkit/build.py +0 -0
  19. wml/iotoolkit/cityscapes_toolkit.py +183 -0
  20. wml/iotoolkit/classification_data_statistics.py +25 -0
  21. wml/iotoolkit/coco_data_fwd.py +225 -0
  22. wml/iotoolkit/coco_keypoints.py +118 -0
  23. wml/iotoolkit/coco_keypoints_fmt2.py +103 -0
  24. wml/iotoolkit/coco_toolkit.py +397 -0
  25. wml/iotoolkit/coco_wholebody.py +269 -0
  26. wml/iotoolkit/common.py +108 -0
  27. wml/iotoolkit/crowd_pose.py +146 -0
  28. wml/iotoolkit/fast_labelme.py +110 -0
  29. wml/iotoolkit/image_folder.py +95 -0
  30. wml/iotoolkit/imgs_cache.py +58 -0
  31. wml/iotoolkit/imgs_reader_mt.py +73 -0
  32. wml/iotoolkit/labelme_base.py +102 -0
  33. wml/iotoolkit/labelme_json_to_img.py +49 -0
  34. wml/iotoolkit/labelme_toolkit.py +117 -0
  35. wml/iotoolkit/labelme_toolkit_fwd.py +733 -0
  36. wml/iotoolkit/labelmemckeypoints_dataset.py +169 -0
  37. wml/iotoolkit/lspet.py +48 -0
  38. wml/iotoolkit/mapillary_vistas_toolkit.py +269 -0
  39. wml/iotoolkit/mat_data.py +90 -0
  40. wml/iotoolkit/mckeypoints_statistics.py +28 -0
  41. wml/iotoolkit/mot_datasets.py +62 -0
  42. wml/iotoolkit/mpii.py +108 -0
  43. wml/iotoolkit/npmckeypoints_dataset.py +164 -0
  44. wml/iotoolkit/o365_to_coco.py +136 -0
  45. wml/iotoolkit/object365_toolkit.py +156 -0
  46. wml/iotoolkit/object365v2_toolkit.py +71 -0
  47. wml/iotoolkit/pascal_voc_data.py +51 -0
  48. wml/iotoolkit/pascal_voc_toolkit.py +194 -0
  49. wml/iotoolkit/pascal_voc_toolkit_fwd.py +473 -0
  50. wml/iotoolkit/penn_action.py +57 -0
  51. wml/iotoolkit/rawframe_dataset.py +129 -0
  52. wml/iotoolkit/rewrite_pascal_voc.py +28 -0
  53. wml/iotoolkit/semantic_data.py +49 -0
  54. wml/iotoolkit/split_file_by_type.py +29 -0
  55. wml/iotoolkit/sports_mot_datasets.py +78 -0
  56. wml/iotoolkit/vis_objectdetection_dataset.py +70 -0
  57. wml/iotoolkit/vis_torch_data.py +39 -0
  58. wml/iotoolkit/yolo_toolkit.py +38 -0
  59. wml/object_detection2/__init__.py +4 -0
  60. wml/object_detection2/basic_visualization.py +37 -0
  61. wml/object_detection2/bboxes.py +812 -0
  62. wml/object_detection2/data_process_toolkit.py +146 -0
  63. wml/object_detection2/keypoints.py +292 -0
  64. wml/object_detection2/mask.py +120 -0
  65. wml/object_detection2/metrics/__init__.py +3 -0
  66. wml/object_detection2/metrics/build.py +15 -0
  67. wml/object_detection2/metrics/classifier_toolkit.py +440 -0
  68. wml/object_detection2/metrics/common.py +71 -0
  69. wml/object_detection2/metrics/mckps_toolkit.py +338 -0
  70. wml/object_detection2/metrics/toolkit.py +1953 -0
  71. wml/object_detection2/npod_toolkit.py +361 -0
  72. wml/object_detection2/odtools.py +243 -0
  73. wml/object_detection2/standard_names.py +75 -0
  74. wml/object_detection2/visualization.py +956 -0
  75. wml/object_detection2/wmath.py +34 -0
  76. wml/semantic/__init__.py +0 -0
  77. wml/semantic/basic_toolkit.py +65 -0
  78. wml/semantic/mask_utils.py +156 -0
  79. wml/semantic/semantic_test.py +21 -0
  80. wml/semantic/structures.py +1 -0
  81. wml/semantic/toolkit.py +105 -0
  82. wml/semantic/visualization_utils.py +658 -0
  83. wml/threadtoolkit.py +50 -0
  84. wml/walgorithm.py +228 -0
  85. wml/wcollections.py +212 -0
  86. wml/wfilesystem.py +487 -0
  87. wml/wml_utils.py +657 -0
  88. wml/wstructures/__init__.py +4 -0
  89. wml/wstructures/common.py +9 -0
  90. wml/wstructures/keypoints_train_toolkit.py +149 -0
  91. wml/wstructures/kps_structures.py +579 -0
  92. wml/wstructures/mask_structures.py +1161 -0
  93. wml/wtorch/__init__.py +8 -0
  94. wml/wtorch/bboxes.py +104 -0
  95. wml/wtorch/classes_suppression.py +24 -0
  96. wml/wtorch/conv_module.py +181 -0
  97. wml/wtorch/conv_ws.py +144 -0
  98. wml/wtorch/data/__init__.py +16 -0
  99. wml/wtorch/data/_utils/__init__.py +45 -0
  100. wml/wtorch/data/_utils/collate.py +183 -0
  101. wml/wtorch/data/_utils/fetch.py +47 -0
  102. wml/wtorch/data/_utils/pin_memory.py +121 -0
  103. wml/wtorch/data/_utils/signal_handling.py +72 -0
  104. wml/wtorch/data/_utils/worker.py +227 -0
  105. wml/wtorch/data/base_data_loader_iter.py +93 -0
  106. wml/wtorch/data/dataloader.py +501 -0
  107. wml/wtorch/data/datapipes/__init__.py +1 -0
  108. wml/wtorch/data/datapipes/iter/__init__.py +12 -0
  109. wml/wtorch/data/datapipes/iter/batch.py +126 -0
  110. wml/wtorch/data/datapipes/iter/callable.py +92 -0
  111. wml/wtorch/data/datapipes/iter/listdirfiles.py +37 -0
  112. wml/wtorch/data/datapipes/iter/loadfilesfromdisk.py +30 -0
  113. wml/wtorch/data/datapipes/iter/readfilesfromtar.py +60 -0
  114. wml/wtorch/data/datapipes/iter/readfilesfromzip.py +63 -0
  115. wml/wtorch/data/datapipes/iter/sampler.py +94 -0
  116. wml/wtorch/data/datapipes/utils/__init__.py +0 -0
  117. wml/wtorch/data/datapipes/utils/common.py +65 -0
  118. wml/wtorch/data/dataset.py +354 -0
  119. wml/wtorch/data/datasets/__init__.py +4 -0
  120. wml/wtorch/data/datasets/common.py +53 -0
  121. wml/wtorch/data/datasets/listdirfilesdataset.py +36 -0
  122. wml/wtorch/data/datasets/loadfilesfromdiskdataset.py +30 -0
  123. wml/wtorch/data/distributed.py +135 -0
  124. wml/wtorch/data/multi_processing_data_loader_iter.py +866 -0
  125. wml/wtorch/data/sampler.py +267 -0
  126. wml/wtorch/data/single_process_data_loader_iter.py +24 -0
  127. wml/wtorch/data/test_data_loader.py +26 -0
  128. wml/wtorch/dataset_toolkit.py +67 -0
  129. wml/wtorch/depthwise_separable_conv_module.py +98 -0
  130. wml/wtorch/dist.py +591 -0
  131. wml/wtorch/dropblock/__init__.py +6 -0
  132. wml/wtorch/dropblock/dropblock.py +228 -0
  133. wml/wtorch/dropblock/dropout.py +40 -0
  134. wml/wtorch/dropblock/scheduler.py +48 -0
  135. wml/wtorch/ema.py +61 -0
  136. wml/wtorch/fc_module.py +73 -0
  137. wml/wtorch/functional.py +34 -0
  138. wml/wtorch/iter_dataset.py +26 -0
  139. wml/wtorch/loss.py +69 -0
  140. wml/wtorch/nets/__init__.py +0 -0
  141. wml/wtorch/nets/ckpt_toolkit.py +219 -0
  142. wml/wtorch/nets/fpn.py +276 -0
  143. wml/wtorch/nets/hrnet/__init__.py +0 -0
  144. wml/wtorch/nets/hrnet/config.py +2 -0
  145. wml/wtorch/nets/hrnet/hrnet.py +494 -0
  146. wml/wtorch/nets/misc.py +249 -0
  147. wml/wtorch/nets/resnet/__init__.py +0 -0
  148. wml/wtorch/nets/resnet/layers/__init__.py +17 -0
  149. wml/wtorch/nets/resnet/layers/aspp.py +144 -0
  150. wml/wtorch/nets/resnet/layers/batch_norm.py +231 -0
  151. wml/wtorch/nets/resnet/layers/blocks.py +111 -0
  152. wml/wtorch/nets/resnet/layers/wrappers.py +110 -0
  153. wml/wtorch/nets/resnet/r50_config.py +38 -0
  154. wml/wtorch/nets/resnet/resnet.py +691 -0
  155. wml/wtorch/nets/shape_spec.py +20 -0
  156. wml/wtorch/nets/simple_fpn.py +101 -0
  157. wml/wtorch/nms.py +109 -0
  158. wml/wtorch/nn.py +896 -0
  159. wml/wtorch/ocr_block.py +193 -0
  160. wml/wtorch/summary.py +331 -0
  161. wml/wtorch/train_toolkit.py +603 -0
  162. wml/wtorch/transformer_blocks.py +266 -0
  163. wml/wtorch/utils.py +719 -0
  164. wml/wtorch/wlr_scheduler.py +100 -0
@@ -0,0 +1,193 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
def Norm2d(in_channels, **kwargs):
    """
    Factory for the normalization layer used throughout this module.

    Kept as a single function so the normalization type can be swapped
    in one place.
    """
    return nn.BatchNorm2d(in_channels, **kwargs)
12
+
13
def BNReLU(ch):
    """Return a Norm2d + ReLU pair for `ch` channels as one sequential block."""
    layers = [Norm2d(ch), nn.ReLU()]
    return nn.Sequential(*layers)
17
+
18
class SpatialGather_Module(nn.Module):
    """
    Aggregate context features weighted by the predicted class distribution.

    Given per-pixel features and per-pixel class probabilities, computes one
    soft-pooled feature vector per class (soft-weighted aggregation).

    Output:
        The correlation of every class map with every feature map,
        shape = [n, num_feats, num_classes, 1]
    """

    def __init__(self, cls_num=0, scale=1):
        super().__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        n, k = probs.size(0), probs.size(1)
        # flatten spatial dims: probs -> [n, k, hw]; softmax over pixels
        probs = F.softmax(self.scale * probs.view(n, k, -1), dim=2)
        # feats -> [n, hw, c]
        feats = feats.view(n, feats.size(1), -1).permute(0, 2, 1)
        # soft-weighted sum over pixels -> [n, k, c]
        context = torch.matmul(probs, feats)
        # -> [n, c, k, 1]
        return context.permute(0, 2, 1).unsqueeze(3)
48
+
49
+
50
class ObjectAttentionBlock(nn.Module):
    '''
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : choose the scale to downsample the input feature
            maps (save memory cost)
    Return:
        N X C X H X W
    '''
    def __init__(self, in_channels, key_channels, scale=1):
        super().__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        # only used in forward() when scale > 1
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        # query transform: pixel features -> key space (two 1x1 conv + BN/ReLU)
        self.f_pixel = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
        )
        # key transform: object (proxy) features -> key space
        self.f_object = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
        )
        # value transform: proxy features -> key space
        self.f_down = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.key_channels),
        )
        # project the attended context back to in_channels
        self.f_up = nn.Sequential(
            nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            BNReLU(self.in_channels),
        )

    def forward(self, x, proxy):
        # x: [N, in_channels, H, W] pixel features.
        # proxy: object-region features; in this file it is the
        # [N, C, K, 1] output of SpatialGather_Module.
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)

        # query: [N, hw, key_channels]; key: [N, key_channels, K];
        # value: [N, K, key_channels]
        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)

        # scaled dot-product attention over object regions: [N, hw, K]
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        # restore spatial layout (uses x's possibly-pooled size)
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            # upsample back to the original input resolution
            context = F.interpolate(input=context, size=(h, w), mode='bilinear',
                                    align_corners=False)

        return context
121
+
122
+
123
class SpatialOCR_Module(nn.Module):
    """
    Implementation of the OCR module: augment every pixel representation with
    the attended global object representation, then fuse the two with a
    1x1 conv + BN/ReLU + dropout.
    """

    def __init__(self, in_channels, key_channels, out_channels, scale=1,
                 dropout=0.1):
        super().__init__()
        self.object_context_block = ObjectAttentionBlock(in_channels,
                                                         key_channels,
                                                         scale)
        # the fusion conv sees [context, feats] concatenated on channels
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0,
                      bias=False),
            BNReLU(out_channels),
            nn.Dropout2d(dropout),
        )

    def forward(self, feats, proxy_feats):
        context = self.object_context_block(feats, proxy_feats)
        fused = torch.cat([context, feats], 1)
        return self.conv_bn_dropout(fused)
150
+
151
class OCRBlock(nn.Module):
    """
    OCR (object-contextual representation) segmentation head.

    Some of the code in this class is borrowed from:
    https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/HRNet-OCR

    forward() returns (cls_out, aux_out, ocr_feats).
    """

    def __init__(self, in_channels, num_classes, key_channels=256, mid_channel=512):
        super().__init__()

        ocr_mid_channels = mid_channel
        ocr_key_channels = key_channels
        # (removed a no-op `num_classes = num_classes` self-assignment)

        # 3x3 conv to bring backbone features to the OCR working width
        self.conv3x3_ocr = nn.Sequential(
            nn.Conv2d(in_channels, ocr_mid_channels,
                      kernel_size=3, stride=1, padding=1),
            BNReLU(ocr_mid_channels),
        )
        self.ocr_gather_head = SpatialGather_Module(num_classes)
        self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
                                                 key_channels=ocr_key_channels,
                                                 out_channels=ocr_mid_channels,
                                                 scale=1,
                                                 dropout=0.05,
                                                 )
        # main classification head on the OCR-refined features
        self.cls_head = nn.Conv2d(
            ocr_mid_channels, num_classes, kernel_size=1, stride=1, padding=0,
            bias=True)
        # auxiliary head on the raw input features; its prediction also serves
        # as the soft object-region map fed to the gather head
        self.aux_head = nn.Sequential(
            nn.Conv2d(in_channels, in_channels,
                      kernel_size=1, stride=1, padding=0),
            BNReLU(in_channels),
            nn.Conv2d(in_channels, num_classes,
                      kernel_size=1, stride=1, padding=0, bias=True)
        )

    def forward(self, high_level_features):
        feats = self.conv3x3_ocr(high_level_features)
        aux_out = self.aux_head(high_level_features)
        context = self.ocr_gather_head(feats, aux_out)
        ocr_feats = self.ocr_distri_head(feats, context)
        cls_out = self.cls_head(ocr_feats)
        return cls_out, aux_out, ocr_feats
wml/wtorch/summary.py ADDED
@@ -0,0 +1,331 @@
1
+ import torch
2
+ from collections import Iterable
3
+ import wml.object_detection2.visualization as odv
4
+ import random
5
+ import numpy as np
6
+ import cv2
7
+ import wml.object_detection2.bboxes as odb
8
+ import wml.basic_img_utils as bwmli
9
+ #from torch.utils.tensorboard import SummaryWriter
10
+ #SummaryWriter.add_image()
11
+ #tb.add_images
12
+
13
def _draw_text_on_image(img, text, font_scale=1.2, color=(0., 255., 0.), pos=None):
    """Draw `text` on `img` in place and return it.

    When `pos` is None the text is placed at the left edge, vertically
    centered using the rendered text height.
    """
    thickness = 2
    if isinstance(text, bytes):
        text = text.decode("utf-8")
    elif not isinstance(text, str):
        text = str(text)
    if pos is None:
        (_, text_h), _baseline = cv2.getTextSize(
            text, fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=font_scale, thickness=thickness)
        pos = (0, (img.shape[0] + text_h) // 2)
    cv2.putText(img, text, pos, cv2.FONT_HERSHEY_COMPLEX, fontScale=font_scale,
                color=color, thickness=thickness)
    return img
24
+
25
def log_all_variable(tb, net: torch.nn.Module, global_step):
    """Log every parameter of `net` (plus BN running stats) to tensorboard.

    Multi-element tensors become histograms, single elements become scalars.
    Best-effort: any logging error is printed and swallowed.
    """
    try:
        for name, param in net.named_parameters():
            # group biases under BIAS/; otherwise split the module path on
            # the first dot so tensorboard groups by top-level module
            if 'bias' in name:
                name = "BIAS/" + name
            elif "." in name:
                name = name.replace(".", "/", 1)
            log_fn = tb.add_histogram if param.numel() > 1 else tb.add_scalar
            log_fn(name, param, global_step)

        state = net.state_dict()
        for name in state:
            if "running" not in name:
                continue
            param = state[name]
            if param.numel() > 1:
                tb.add_histogram("BN/" + name, param, global_step)
            else:
                tb.add_scalar("BN/" + name, param, global_step)
    except Exception as e:
        print("ERROR:", e)
47
+
48
def log_all_variable_min_max(tb, net: torch.nn.Module, global_step):
    """Log min/max/mean/std scalars for every parameter of `net`.

    Cheaper alternative to log_all_variable for large nets; BN running
    statistics are still logged as histograms. Best-effort: errors are
    printed and swallowed.
    """
    try:
        for name, param in net.named_parameters():
            if 'bias' in name:
                name = "BIAS/" + name
            elif "." in name:
                name = name.replace(".", "/", 1)
            if param.numel() <= 1:
                tb.add_scalar(name, param, global_step)
                continue
            std, mean = torch.std_mean(param)
            for suffix, value in (("_min", torch.min(param)),
                                  ("_max", torch.max(param)),
                                  ("_mean", mean),
                                  ("_std", std)):
                tb.add_scalar(name + suffix, value, global_step)

        state = net.state_dict()
        for name in state:
            if "running" not in name:
                continue
            param = state[name]
            if param.numel() > 1:
                tb.add_histogram("BN/" + name, param, global_step)
            else:
                tb.add_scalar("BN/" + name, param, global_step)
    except Exception as e:
        print("ERROR:", e)
74
+
75
def log_basic_info(tb, name, value: torch.Tensor, global_step):
    """Log summary statistics (min/max/mean/std) of `value`, or the value
    itself when it is a single element."""
    if value.numel() <= 1:
        tb.add_scalar(name, value, global_step)
        return
    stats = {
        "min": torch.min(value),
        "max": torch.max(value),
        "mean": torch.mean(value),
        "std": torch.std(value),
    }
    for key, stat in stats.items():
        tb.add_scalar(f"{name}/{key}", stat, global_step)
87
+
88
def add_image_with_label(tb, name, image, label, global_step):
    """Write a single CHW tensor image to tensorboard with `label` drawn on it."""
    hwc = image.numpy().transpose(1, 2, 0)
    hwc = _draw_text_on_image(hwc, str(label))
    tb.add_image(name, hwc.transpose(2, 0, 1), global_step)
95
+
96
def add_images_with_label(tb, name, image, label, global_step, font_scale=1.2):
    """Write a batch of NCHW images to tensorboard with labels drawn on them.

    `label` may be a non-iterable scalar (drawn on the first image only), a
    length-1 iterable (same), or one label per image. Any other length is
    reported with an error message and nothing is logged.
    """
    # BUGFIX: the module header's `from collections import Iterable` raises
    # ImportError on Python 3.10+; import the ABC from collections.abc here.
    from collections.abc import Iterable

    if isinstance(image, torch.Tensor):
        image = image.numpy()
    # NCHW -> NHWC so text can be drawn per image
    image = image.transpose(0, 2, 3, 1)
    image = np.ascontiguousarray(image)
    if not isinstance(label, Iterable):
        image[0] = _draw_text_on_image(image[0], str(label), font_scale=font_scale)
    elif len(label) == 1:
        image[0] = _draw_text_on_image(image[0], str(label[0]), font_scale=font_scale)
    elif len(label) == image.shape[0]:
        for i in range(len(label)):
            image[i] = _draw_text_on_image(image[i], str(label[i]), font_scale=font_scale)
    else:
        print(f"ERROR label {label}")
        return

    # back to NCHW for tensorboard
    image = image.transpose(0, 3, 1, 2)
    tb.add_images(name, image, global_step)
116
+
117
def log_feature_map(tb, name, tensor, global_step, random_index=True):
    '''
    Log one sample of a [B, C, H, W] feature map as C grayscale images,
    min-max normalized to [0, 1]. `random_index` picks a random batch
    element; otherwise the first one is used.
    '''
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().detach().numpy()

    idx = random.randint(0, tensor.shape[0] - 1) if random_index else 0
    # [C, H, W] -> [C, 1, H, W] so each channel renders as one image
    fmap = np.expand_dims(tensor[idx], axis=1)
    lo, hi = np.min(fmap), np.max(fmap)
    fmap = (fmap - lo) / (hi - lo + 1e-8)
    tb.add_images(name, fmap, global_step)
134
+
135
def try_log_rgb_feature_map(tb, name, tensor, global_step, random_index=True,
                            min_upper_bounder=None, max_lower_bounder=None):
    """Log one sample of a [B, C, H, W] tensor.

    C <= 3 channels are rendered as a single RGB image (zero-padded when
    C == 2); more channels fall back to per-channel grayscale images.
    min_upper_bounder / max_lower_bounder widen the normalization range.
    """
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().detach().numpy()

    idx = random.randint(0, tensor.shape[0] - 1) if random_index else 0
    channels = tensor.shape[1]
    data = tensor[idx]
    lo = np.min(data)
    if min_upper_bounder is not None:
        lo = np.minimum(lo, min_upper_bounder)
    hi = np.max(data)
    if max_lower_bounder is not None:
        hi = np.maximum(hi, max_lower_bounder)
    data = (data - lo) / (hi - lo + 1e-8)
    if channels > 3:
        tb.add_images(name, np.expand_dims(data, axis=1), global_step)
        return
    if channels == 2:
        # pad a zero third channel so the result renders as RGB
        _, h, w = data.shape
        data = np.concatenate([data, np.zeros([1, h, w], dtype=data.dtype)], axis=0)
    tb.add_image(name, data, global_step)
161
+
162
def log_heatmap_on_img(tb, name, img, heat_map, global_step,
                       min_upper_bounder=None, max_lower_bounder=None):
    '''
    img: [H,W,C] (0~255)
    heat_map: [C,H,W]
    Normalize the heatmap to [0, 1] (optionally widening the range with the
    two bounder arguments) and draw it over `img`.
    '''
    heat_map = heat_map.astype(np.float32)
    lo = np.min(heat_map)
    if min_upper_bounder is not None:
        lo = np.minimum(lo, min_upper_bounder)
    hi = np.max(heat_map)
    if max_lower_bounder is not None:
        hi = np.maximum(hi, max_lower_bounder)
    heat_map = (heat_map - lo) / (hi - lo + 1e-8)
    overlay = odv.try_draw_rgb_heatmap_on_image(image=img, scores=heat_map)
    tb.add_image(name, overlay, global_step, dataformats="HWC")
178
+
179
def log_heatmap(tb, name, heat_map, global_step,
                min_upper_bounder=None, max_lower_bounder=None):
    '''
    heat_map: [C,H,W]
    Render the normalized heatmap alone on a black canvas
    (red = positive, blue = negative).
    '''
    heat_map = heat_map.astype(np.float32)
    lo = np.min(heat_map)
    if min_upper_bounder is not None:
        lo = np.minimum(lo, min_upper_bounder)
    hi = np.max(heat_map)
    if max_lower_bounder is not None:
        hi = np.maximum(hi, max_lower_bounder)
    heat_map = (heat_map - lo) / (hi - lo + 1e-8)
    canvas = np.zeros([heat_map.shape[1], heat_map.shape[2], 3], dtype=np.uint8)
    img = odv.try_draw_rgb_heatmap_on_image(image=canvas,
                                            color_pos=(255, 0, 0),
                                            color_neg=(0, 0, 255),
                                            scores=heat_map, alpha=1.0)
    tb.add_image(name, img, global_step, dataformats="HWC")
196
+
197
+
198
def log_heatmap_on_imgv2(tb, name, img, heat_map, global_step,
                         min_upper_bounder=None, max_lower_bounder=None):
    '''
    Same as log_heatmap_on_img but uses a richer pseudo-color palette.
    img: [H,W,C] (0~255)
    heat_map: [C,H,W]
    '''
    heat_map = heat_map.astype(np.float32)
    lo = np.min(heat_map)
    if min_upper_bounder is not None:
        lo = np.minimum(lo, min_upper_bounder)
    hi = np.max(heat_map)
    if max_lower_bounder is not None:
        hi = np.maximum(hi, max_lower_bounder)
    heat_map = (heat_map - lo) / (hi - lo + 1e-8)
    overlay = odv.try_draw_rgb_heatmap_on_imagev2(
        image=img,
        palette=[(0, (0, 0, 0)), (0.5, (0, 0, 0)), (1.0, (255, 0, 0))],
        scores=heat_map)
    tb.add_image(name, overlay, global_step, dataformats="HWC")
216
+
217
def log_heatmapv2(tb, name, heat_map, global_step,
                  min_upper_bounder=None, max_lower_bounder=None):
    '''
    heat_map: [C,H,W]; channels are summed before rendering.
    Uses a blue-white-red pseudo-color palette.
    '''
    heat_map = np.sum(heat_map.astype(np.float32), axis=0, keepdims=False)
    lo = np.min(heat_map)
    if min_upper_bounder is not None:
        lo = np.minimum(lo, min_upper_bounder)
    hi = np.max(heat_map)
    if max_lower_bounder is not None:
        hi = np.maximum(hi, max_lower_bounder)
    heat_map = (heat_map - lo) / (hi - lo + 1e-8)

    palette = [(0, (0, 0, 255)), (0.5, (255, 255, 255)), (1.0, (255, 0, 0))]
    img = bwmli.pseudocolor_img(img=heat_map, palette=palette, auto_norm=False)
    tb.add_image(name, img.astype(np.uint8), global_step, dataformats="HWC")
236
+
237
def add_video_with_label(tb, name, video, label, global_step, fps=4, font_scale=1.2):
    '''
    Log a batch of video clips with a per-clip label drawn on every frame.

    Args:
        video: (N, T, C, H, W) tensor or ndarray
        label: one label per clip, or None to skip drawing
        fps: frame rate passed through to tensorboard
    '''
    if isinstance(video, torch.Tensor):
        video = video.numpy()
    # -> (N, T, H, W, C) so frames can be drawn on
    video = np.ascontiguousarray(video.transpose(0, 1, 3, 4, 2))
    if label is not None:
        for clip_idx in range(video.shape[0]):
            text = label[clip_idx]
            for frame_idx in range(video.shape[1]):
                _draw_text_on_image(video[clip_idx, frame_idx], text,
                                    font_scale=font_scale)
    # back to (N, T, C, H, W) for tensorboard
    video = video.transpose(0, 1, 4, 2, 3)
    # BUGFIX: `fps` was accepted but never forwarded to tb.add_video
    tb.add_video(name, video, global_step, fps=fps)
263
+
264
def log_mask(tb, tag, images, masks, step, color_map, img_min=None, img_max=None,
             ignore_label=255, max_images=4, save_raw_imgs=False):
    """Log up to `max_images` images with their semantic masks drawn on top.

    images: [B,C,H,W] tensor; masks: [B,H,W] label tensor.
    img_min/img_max define the intensity range used to rescale images to
    0..255; when either is missing, both are taken from the batch.
    """
    images = images[:max_images]
    masks = masks[:max_images]
    # BUGFIX: was `if img_min or img_max is None:` — operator precedence made
    # it recompute the range whenever img_min was truthy (discarding caller
    # values) and crash when only img_min was None.
    if img_min is None or img_max is None:
        img_min = torch.min(images)
        img_max = torch.max(images)
    images = (images - img_min) * 255 / (img_max - img_min + 1e-8)
    images = torch.clip(images, 0, 255)
    images = images.to(torch.uint8)
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    masks = masks.cpu().numpy()
    res_imgs = []
    for img, msk in zip(images, masks):
        r_img = odv.draw_semantic_on_image(img, msk, color_map,
                                           ignored_label=ignore_label)
        res_imgs.append(r_img)
    res_images = np.stack(res_imgs, axis=0)
    tb.add_images(tag, res_images, step, dataformats='NHWC')
    if save_raw_imgs:
        tb.add_images(tag + "_raw", images, step, dataformats='NHWC')
283
+
284
+
285
def log_optimizer(tb, optimizer, step, name=""):
    """Log learning rate and weight decay of every optimizer param group."""
    for idx, group in enumerate(optimizer.param_groups):
        prefix = f"{name} optimizer/{idx}_{len(group['params'])}"
        tb.add_scalar(prefix + "_lr", group['lr'], step)
        tb.add_scalar(prefix + "_wd", group['weight_decay'], step)
290
+
291
+
292
def log_imgs_with_bboxes(tb, name, imgs, targets, step, max_imgs=None):
    '''
    imgs: [B,C,H,W] [0-255]
    targets: rows of [label,x0,y0,x1,y1] per image (indexing suggests
        [B,N,5] with all-zero rows as padding — TODO confirm with callers)
    '''
    if max_imgs is not None and max_imgs > 0:
        imgs = imgs[:max_imgs]
        targets = targets[:max_imgs]

    if torch.is_tensor(imgs):
        imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    if torch.is_tensor(targets):
        targets = targets.detach().cpu().numpy()

    bboxes = odb.npchangexyorder(targets[..., 1:5])
    labels = targets[..., 0].astype(np.int32)

    drawn = []
    for img, cur_bboxes, cur_labels in zip(imgs, bboxes, labels):
        img = np.ascontiguousarray(np.transpose(img, [1, 2, 0]))
        # padding rows sum to zero; keep only the real boxes at the front
        valid = np.count_nonzero(np.sum(cur_bboxes, axis=-1) > 0)
        img = odv.draw_bboxes(img, cur_labels[:valid], bboxes=cur_bboxes[:valid],
                              is_relative_coordinate=False)
        drawn.append(img)

    drawn = np.ascontiguousarray(np.array(drawn))
    tb.add_images(name, drawn, step, dataformats='NHWC')
325
+
326
+
327
def log_semantic_seg(tb, name, img, seg, global_step, alpha=0.4, ignore_idx=255):
    """Overlay segmentation `seg` on `img` and log the result as an HWC image."""
    blended = odv.draw_seg_on_img(img, seg, alpha=alpha, ignore_idx=ignore_idx)
    tb.add_image(name, blended, global_step, dataformats="HWC")
330
+
331
+