supervisely 6.73.410__py3-none-any.whl → 6.73.470__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of supervisely might be problematic. Click here for more details.

Files changed (190) hide show
  1. supervisely/__init__.py +136 -1
  2. supervisely/_utils.py +81 -0
  3. supervisely/annotation/json_geometries_map.py +2 -0
  4. supervisely/annotation/label.py +80 -3
  5. supervisely/api/annotation_api.py +9 -9
  6. supervisely/api/api.py +67 -43
  7. supervisely/api/app_api.py +72 -5
  8. supervisely/api/dataset_api.py +108 -33
  9. supervisely/api/entity_annotation/figure_api.py +113 -49
  10. supervisely/api/image_api.py +82 -0
  11. supervisely/api/module_api.py +10 -0
  12. supervisely/api/nn/deploy_api.py +15 -9
  13. supervisely/api/nn/ecosystem_models_api.py +201 -0
  14. supervisely/api/nn/neural_network_api.py +12 -3
  15. supervisely/api/pointcloud/pointcloud_api.py +38 -0
  16. supervisely/api/pointcloud/pointcloud_episode_annotation_api.py +3 -0
  17. supervisely/api/project_api.py +213 -6
  18. supervisely/api/task_api.py +11 -1
  19. supervisely/api/video/video_annotation_api.py +4 -2
  20. supervisely/api/video/video_api.py +79 -1
  21. supervisely/api/video/video_figure_api.py +24 -11
  22. supervisely/api/volume/volume_api.py +38 -0
  23. supervisely/app/__init__.py +1 -1
  24. supervisely/app/content.py +14 -6
  25. supervisely/app/fastapi/__init__.py +1 -0
  26. supervisely/app/fastapi/custom_static_files.py +1 -1
  27. supervisely/app/fastapi/multi_user.py +88 -0
  28. supervisely/app/fastapi/subapp.py +175 -42
  29. supervisely/app/fastapi/templating.py +1 -1
  30. supervisely/app/fastapi/websocket.py +77 -9
  31. supervisely/app/singleton.py +21 -0
  32. supervisely/app/v1/app_service.py +18 -2
  33. supervisely/app/v1/constants.py +7 -1
  34. supervisely/app/widgets/__init__.py +11 -1
  35. supervisely/app/widgets/agent_selector/template.html +1 -0
  36. supervisely/app/widgets/card/card.py +20 -0
  37. supervisely/app/widgets/dataset_thumbnail/dataset_thumbnail.py +11 -2
  38. supervisely/app/widgets/dataset_thumbnail/template.html +3 -1
  39. supervisely/app/widgets/deploy_model/deploy_model.py +750 -0
  40. supervisely/app/widgets/dialog/dialog.py +12 -0
  41. supervisely/app/widgets/dialog/template.html +2 -1
  42. supervisely/app/widgets/dropdown_checkbox_selector/__init__.py +0 -0
  43. supervisely/app/widgets/dropdown_checkbox_selector/dropdown_checkbox_selector.py +87 -0
  44. supervisely/app/widgets/dropdown_checkbox_selector/template.html +12 -0
  45. supervisely/app/widgets/ecosystem_model_selector/__init__.py +0 -0
  46. supervisely/app/widgets/ecosystem_model_selector/ecosystem_model_selector.py +195 -0
  47. supervisely/app/widgets/experiment_selector/experiment_selector.py +454 -263
  48. supervisely/app/widgets/fast_table/fast_table.py +713 -126
  49. supervisely/app/widgets/fast_table/script.js +492 -95
  50. supervisely/app/widgets/fast_table/style.css +54 -0
  51. supervisely/app/widgets/fast_table/template.html +45 -5
  52. supervisely/app/widgets/heatmap/__init__.py +0 -0
  53. supervisely/app/widgets/heatmap/heatmap.py +523 -0
  54. supervisely/app/widgets/heatmap/script.js +378 -0
  55. supervisely/app/widgets/heatmap/style.css +227 -0
  56. supervisely/app/widgets/heatmap/template.html +21 -0
  57. supervisely/app/widgets/input_tag/input_tag.py +102 -15
  58. supervisely/app/widgets/input_tag_list/__init__.py +0 -0
  59. supervisely/app/widgets/input_tag_list/input_tag_list.py +274 -0
  60. supervisely/app/widgets/input_tag_list/template.html +70 -0
  61. supervisely/app/widgets/radio_table/radio_table.py +10 -2
  62. supervisely/app/widgets/radio_tabs/radio_tabs.py +18 -2
  63. supervisely/app/widgets/radio_tabs/template.html +1 -0
  64. supervisely/app/widgets/select/select.py +6 -4
  65. supervisely/app/widgets/select_dataset/select_dataset.py +6 -0
  66. supervisely/app/widgets/select_dataset_tree/select_dataset_tree.py +83 -7
  67. supervisely/app/widgets/table/table.py +68 -13
  68. supervisely/app/widgets/tabs/tabs.py +22 -6
  69. supervisely/app/widgets/tabs/template.html +5 -1
  70. supervisely/app/widgets/transfer/style.css +3 -0
  71. supervisely/app/widgets/transfer/template.html +3 -1
  72. supervisely/app/widgets/transfer/transfer.py +48 -45
  73. supervisely/app/widgets/tree_select/tree_select.py +2 -0
  74. supervisely/convert/image/csv/csv_converter.py +24 -15
  75. supervisely/convert/pointcloud/nuscenes_conv/nuscenes_converter.py +43 -41
  76. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_converter.py +75 -51
  77. supervisely/convert/pointcloud_episodes/nuscenes_conv/nuscenes_helper.py +137 -124
  78. supervisely/convert/video/video_converter.py +2 -2
  79. supervisely/geometry/polyline_3d.py +110 -0
  80. supervisely/io/env.py +161 -1
  81. supervisely/nn/artifacts/__init__.py +1 -1
  82. supervisely/nn/artifacts/artifacts.py +10 -2
  83. supervisely/nn/artifacts/detectron2.py +1 -0
  84. supervisely/nn/artifacts/hrda.py +1 -0
  85. supervisely/nn/artifacts/mmclassification.py +20 -0
  86. supervisely/nn/artifacts/mmdetection.py +5 -3
  87. supervisely/nn/artifacts/mmsegmentation.py +1 -0
  88. supervisely/nn/artifacts/ritm.py +1 -0
  89. supervisely/nn/artifacts/rtdetr.py +1 -0
  90. supervisely/nn/artifacts/unet.py +1 -0
  91. supervisely/nn/artifacts/utils.py +3 -0
  92. supervisely/nn/artifacts/yolov5.py +2 -0
  93. supervisely/nn/artifacts/yolov8.py +1 -0
  94. supervisely/nn/benchmark/semantic_segmentation/metric_provider.py +18 -18
  95. supervisely/nn/experiments.py +9 -0
  96. supervisely/nn/inference/cache.py +37 -17
  97. supervisely/nn/inference/gui/serving_gui_template.py +39 -13
  98. supervisely/nn/inference/inference.py +953 -211
  99. supervisely/nn/inference/inference_request.py +15 -8
  100. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  101. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  102. supervisely/nn/inference/predict_app/__init__.py +0 -0
  103. supervisely/nn/inference/predict_app/gui/__init__.py +0 -0
  104. supervisely/nn/inference/predict_app/gui/classes_selector.py +160 -0
  105. supervisely/nn/inference/predict_app/gui/gui.py +915 -0
  106. supervisely/nn/inference/predict_app/gui/input_selector.py +344 -0
  107. supervisely/nn/inference/predict_app/gui/model_selector.py +77 -0
  108. supervisely/nn/inference/predict_app/gui/output_selector.py +179 -0
  109. supervisely/nn/inference/predict_app/gui/preview.py +93 -0
  110. supervisely/nn/inference/predict_app/gui/settings_selector.py +881 -0
  111. supervisely/nn/inference/predict_app/gui/tags_selector.py +110 -0
  112. supervisely/nn/inference/predict_app/gui/utils.py +399 -0
  113. supervisely/nn/inference/predict_app/predict_app.py +176 -0
  114. supervisely/nn/inference/session.py +47 -39
  115. supervisely/nn/inference/tracking/bbox_tracking.py +5 -1
  116. supervisely/nn/inference/tracking/point_tracking.py +5 -1
  117. supervisely/nn/inference/tracking/tracker_interface.py +4 -0
  118. supervisely/nn/inference/uploader.py +9 -5
  119. supervisely/nn/model/model_api.py +44 -22
  120. supervisely/nn/model/prediction.py +15 -1
  121. supervisely/nn/model/prediction_session.py +70 -14
  122. supervisely/nn/prediction_dto.py +7 -0
  123. supervisely/nn/tracker/__init__.py +6 -8
  124. supervisely/nn/tracker/base_tracker.py +54 -0
  125. supervisely/nn/tracker/botsort/__init__.py +1 -0
  126. supervisely/nn/tracker/botsort/botsort_config.yaml +30 -0
  127. supervisely/nn/tracker/botsort/osnet_reid/__init__.py +0 -0
  128. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  129. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  130. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  131. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  132. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  133. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  134. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  135. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  136. supervisely/nn/tracker/botsort_tracker.py +273 -0
  137. supervisely/nn/tracker/calculate_metrics.py +264 -0
  138. supervisely/nn/tracker/utils.py +273 -0
  139. supervisely/nn/tracker/visualize.py +520 -0
  140. supervisely/nn/training/gui/gui.py +152 -49
  141. supervisely/nn/training/gui/hyperparameters_selector.py +1 -1
  142. supervisely/nn/training/gui/model_selector.py +8 -6
  143. supervisely/nn/training/gui/train_val_splits_selector.py +144 -71
  144. supervisely/nn/training/gui/training_artifacts.py +3 -1
  145. supervisely/nn/training/train_app.py +225 -46
  146. supervisely/project/pointcloud_episode_project.py +12 -8
  147. supervisely/project/pointcloud_project.py +12 -8
  148. supervisely/project/project.py +221 -75
  149. supervisely/template/experiment/experiment.html.jinja +105 -55
  150. supervisely/template/experiment/experiment_generator.py +258 -112
  151. supervisely/template/experiment/header.html.jinja +31 -13
  152. supervisely/template/experiment/sly-style.css +7 -2
  153. supervisely/versions.json +3 -1
  154. supervisely/video/sampling.py +42 -20
  155. supervisely/video/video.py +41 -12
  156. supervisely/video_annotation/video_figure.py +38 -4
  157. supervisely/volume/stl_converter.py +2 -0
  158. supervisely/worker_api/agent_rpc.py +24 -1
  159. supervisely/worker_api/rpc_servicer.py +31 -7
  160. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/METADATA +22 -14
  161. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/RECORD +167 -148
  162. supervisely_lib/__init__.py +6 -1
  163. supervisely/app/widgets/experiment_selector/style.css +0 -27
  164. supervisely/app/widgets/experiment_selector/template.html +0 -61
  165. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  166. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  167. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  168. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  169. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  170. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  171. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  172. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  173. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  174. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  175. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  176. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  177. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  178. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  179. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  180. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  181. supervisely/nn/tracker/tracker.py +0 -285
  182. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  183. supervisely/nn/tracking/__init__.py +0 -1
  184. supervisely/nn/tracking/boxmot.py +0 -114
  185. supervisely/nn/tracking/tracking.py +0 -24
  186. /supervisely/{nn/tracker/utils → app/widgets/deploy_model}/__init__.py +0 -0
  187. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/LICENSE +0 -0
  188. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/WHEEL +0 -0
  189. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/entry_points.txt +0 -0
  190. {supervisely-6.73.410.dist-info → supervisely-6.73.470.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,54 @@
1
+ from typing import List, Dict, Any
2
+ import supervisely as sly
3
+ from supervisely import Annotation, VideoAnnotation
4
+ import numpy as np
5
+
6
class BaseTracker:
    """Base class for trackers: resolves the compute device and declares the API.

    ``update``, ``track``, ``video_annotation`` and ``get_default_params`` raise
    ``NotImplementedError`` here and are meant to be overridden by subclasses.
    """

    def __init__(self, settings: dict = None, device: str = None):
        """Store ``settings`` and resolve ``self.device``.

        Device precedence: ``settings["device"]`` (when present and not None),
        then the ``device`` argument, then automatic detection. The value
        ``"auto"`` from either source resolves to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true and to ``"cpu"`` otherwise.

        Args:
            settings: Tracker settings; ``None`` or an empty value becomes ``{}``.
            device: Fallback device used when ``settings`` carries no device.

        Raises:
            ValueError: If the resolved device is neither ``"cpu"`` nor a
                string starting with ``"cuda"``.
        """
        import torch  # pylint: disable=import-error

        self.settings = settings or {}
        auto_device = "cuda" if torch.cuda.is_available() else "cpu"

        requested = self.settings.get("device")
        if requested is None:
            requested = device
        # "auto" is honoured no matter which source supplied it.
        if requested is None or requested == "auto":
            requested = auto_device
        self.device = requested

        self._validate_device()

    def update(self, frame: np.ndarray, annotation: Annotation) -> List[Dict[str, Any]]:
        """Process one frame with its annotation; must be overridden."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    def reset(self) -> None:
        """Reset tracker state."""
        pass

    def track(self, frames: List[np.ndarray], annotations: List[Annotation]) -> VideoAnnotation:
        """Process a sequence of frames with their annotations; must be overridden."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    @property
    def video_annotation(self) -> VideoAnnotation:
        """Return the accumulated VideoAnnotation."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    @classmethod
    def get_default_params(cls) -> Dict[str, Any]:
        """
        Get default configurable parameters for this tracker.
        Must be implemented in subclass.
        """
        raise NotImplementedError(
            f"Method get_default_params() must be implemented in {cls.__name__}"
        )

    def _validate_device(self) -> None:
        """Raise ``ValueError`` unless ``self.device`` is "cpu" or starts with "cuda"."""
        # Compare on the string form so a non-str device object does not fail
        # with AttributeError on ``startswith``.
        device_name = str(self.device)
        if device_name != "cpu" and not device_name.startswith("cuda"):
            raise ValueError(
                f"Invalid device '{self.device}'. Supported devices are 'cpu' or 'cuda'."
            )
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,30 @@
1
+ # botsort_config.yaml
2
+
3
+ device: "auto" # "cuda" or "cpu", "auto" will use cuda if available
4
+ fp16: false
5
+
6
+ # BoTSORT tracking parameters
7
+ track_high_thresh: 0.6
8
+ track_low_thresh: 0.1
9
+ new_track_thresh: 0.7
10
+ track_buffer: 30
11
+ match_thresh: 0.8
12
+ min_box_area: 10.0
13
+
14
+ # Appearance model (ReID)
15
+ with_reid: true
16
+ reid_weights: null
17
+ proximity_thresh: 0.5
18
+ appearance_thresh: 0.25
19
+
20
+ # Algorithm flags
21
+ fuse_score: false
22
+ ablation: false
23
+ mot20: false
24
+
25
+ # Camera motion compensation
26
+ cmc_method: "sparseOptFlow"
27
+
28
+ # Performance
29
+ fps: 30
30
+
File without changes
@@ -0,0 +1,566 @@
1
+
2
+ from __future__ import absolute_import, division
3
+
4
+ import warnings
5
+ from supervisely import logger
6
+
7
+ try:
8
+ import torch # pylint: disable=import-error
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ except ImportError:
12
+ logger.warning("torch is not installed, OSNet re-ID cannot be used.")
13
+
14
+
15
+ __all__ = ["osnet_x1_0", "osnet_x0_75", "osnet_x0_5", "osnet_x0_25", "osnet_ibn_x1_0"]
16
+
17
# Download location of the pretrained checkpoint for each model name.
pretrained_urls = dict(
    osnet_x1_0="https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY",
    osnet_x0_75="https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq",
    osnet_x0_5="https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i",
    osnet_x0_25="https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs",
    osnet_ibn_x1_0="https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l",
)
24
+
25
+
26
+ ##########
27
+ # Basic layers
28
+ ##########
29
class ConvLayer(nn.Module):
    """Bias-free 2D convolution followed by a normalization layer and ReLU.

    The normalization is ``InstanceNorm2d`` (affine) when ``IN`` is true and
    ``BatchNorm2d`` otherwise; it is stored as ``bn`` in both cases.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        IN=False,
    ):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, bias=False, groups=groups)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.bn = (
            nn.InstanceNorm2d(out_channels, affine=True)
            if IN
            else nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, normalization and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))
63
+
64
+
65
class Conv1x1(nn.Module):
    """Bias-free 1x1 convolution followed by batch normalization and ReLU."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super().__init__()
        conv_options = dict(stride=stride, padding=0, bias=False, groups=groups)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))
87
+
88
+
89
class Conv1x1Linear(nn.Module):
    """Bias-free 1x1 convolution followed by batch normalization, with no activation."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv_options = dict(stride=stride, padding=0, bias=False)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply convolution and then batch normalization."""
        return self.bn(self.conv(x))
103
+
104
+
105
class Conv3x3(nn.Module):
    """Bias-free 3x3 convolution (padding 1) followed by batch normalization and ReLU."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super().__init__()
        conv_options = dict(stride=stride, padding=1, bias=False, groups=groups)
        self.conv = nn.Conv2d(in_channels, out_channels, 3, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU, in that order."""
        return self.relu(self.bn(self.conv(x)))
127
+
128
+
129
class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    A bias-free 1x1 convolution, then a bias-free depthwise 3x3 convolution
    (``groups=out_channels``), then batch normalization and ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pointwise_options = dict(stride=1, padding=0, bias=False)
        depthwise_options = dict(stride=1, padding=1, bias=False, groups=out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, **pointwise_options)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, **depthwise_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply both convolutions, then batch normalization and ReLU."""
        mixed = self.conv2(self.conv1(x))
        return self.relu(self.bn(mixed))
158
+
159
+
160
+ ##########
161
+ # Building blocks for omni-scale feature learning
162
+ ##########
163
class ChannelGate(nn.Module):
    """Mini-network producing channel-wise gates from its input tensor.

    The input is globally average-pooled, passed through two 1x1 convolutions
    (with an optional LayerNorm and a ReLU in between) and an optional gate
    activation. ``forward`` returns either the gates themselves or the input
    multiplied by them, depending on ``return_gates``.
    """

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation="sigmoid",
        reduction=16,
        layer_norm=False,
    ):
        super().__init__()
        gate_count = in_channels if num_gates is None else num_gates
        squeezed_channels = in_channels // reduction

        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels, squeezed_channels, kernel_size=1, bias=True, padding=0
        )
        self.norm1 = (
            nn.LayerNorm((squeezed_channels, 1, 1)) if layer_norm else None
        )
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(
            squeezed_channels, gate_count, kernel_size=1, bias=True, padding=0
        )

        # "linear" means the raw fc2 output is used as the gate.
        if gate_activation == "linear":
            self.gate_activation = None
        elif gate_activation == "relu":
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == "sigmoid":
            self.gate_activation = nn.Sigmoid()
        else:
            raise RuntimeError("Unknown gate activation: {}".format(gate_activation))

    def forward(self, x):
        """Return the gates if ``return_gates`` is set, else ``x`` scaled by the gates."""
        gates = self.fc1(self.global_avgpool(x))
        if self.norm1 is not None:
            gates = self.norm1(gates)
        gates = self.fc2(self.relu(gates))
        if self.gate_activation is not None:
            gates = self.gate_activation(gates)
        return gates if self.return_gates else x * gates
212
+
213
+
214
class OSBlock(nn.Module):
    """Omni-scale feature learning block.

    A 1x1 reduction feeds four parallel streams of one to four stacked
    ``LightConv3x3`` layers. Each stream output goes through the same
    ``ChannelGate``, the gated outputs are summed, expanded by a linear 1x1
    convolution and added to the (optionally projected) input.
    """

    def __init__(
        self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs
    ):
        super().__init__()
        mid_channels = out_channels // bottleneck_reduction

        def light_stack(depth):
            # ``depth`` LightConv3x3 layers chained at the bottleneck width.
            return nn.Sequential(
                *[LightConv3x3(mid_channels, mid_channels) for _ in range(depth)]
            )

        self.conv1 = Conv1x1(in_channels, mid_channels)
        # The single-layer stream is a bare module, not wrapped in Sequential.
        self.conv2a = LightConv3x3(mid_channels, mid_channels)
        self.conv2b = light_stack(2)
        self.conv2c = light_stack(3)
        self.conv2d = light_stack(4)
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        # A projection on the skip path is only built when the widths differ.
        self.downsample = (
            Conv1x1Linear(in_channels, out_channels)
            if in_channels != out_channels
            else None
        )
        self.IN = nn.InstanceNorm2d(out_channels, affine=True) if IN else None

    def forward(self, x):
        """Run the block and return the ReLU of the residual sum."""
        reduced = self.conv1(x)
        streams = [
            branch(reduced)
            for branch in (self.conv2a, self.conv2b, self.conv2c, self.conv2d)
        ]
        gated = [self.gate(stream) for stream in streams]
        fused = gated[0] + gated[1] + gated[2] + gated[3]
        expanded = self.conv3(fused)

        shortcut = x if self.downsample is None else self.downsample(x)
        out = expanded + shortcut
        if self.IN is not None:
            out = self.IN(out)
        return F.relu(out)
263
+
264
+
265
+ ##########
266
+ # Network architecture
267
+ ##########
268
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. TPAMI, 2021.
    """

    def __init__(
        self,
        num_classes,
        blocks,
        layers,
        channels,
        feature_dim=512,
        loss="softmax",
        IN=False,
        **kwargs,
    ):
        super().__init__()
        stage_count = len(blocks)
        assert stage_count == len(layers)
        assert stage_count == len(channels) - 1
        self.loss = loss
        self.feature_dim = feature_dim

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        # Only the first stage is given the IN flag; the later two use the default.
        self.conv2 = self._make_layer(
            blocks[0], layers[0], channels[0], channels[1],
            reduce_spatial_size=True, IN=IN,
        )
        self.conv3 = self._make_layer(
            blocks[1], layers[1], channels[1], channels[2],
            reduce_spatial_size=True,
        )
        self.conv4 = self._make_layer(
            blocks[2], layers[2], channels[2], channels[3],
            reduce_spatial_size=False,
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # Embedding head; may overwrite self.feature_dim.
        self.fc = self._construct_fc_layer(
            self.feature_dim, channels[3], dropout_p=None
        )
        # Identity classification layer.
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _make_layer(
        self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False
    ):
        """Build one stage: ``block`` instances plus an optional 2x downsampling tail."""
        stage = [block(in_channels, out_channels, IN=IN)]
        stage.extend(
            block(out_channels, out_channels, IN=IN) for _ in range(1, layer)
        )
        if reduce_spatial_size:
            transition = nn.Sequential(
                Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2)
            )
            stage.append(transition)
        return nn.Sequential(*stage)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Build the Linear/BatchNorm1d/ReLU head and record its output width.

        Returns ``None`` (and sets ``feature_dim`` to ``input_dim``) when
        ``fc_dims`` is ``None`` or negative.
        """
        if fc_dims is None or fc_dims < 0:
            self.feature_dim = input_dim
            return None

        widths = [fc_dims] if isinstance(fc_dims, int) else fc_dims

        head = []
        current_dim = input_dim
        for width in widths:
            head.append(nn.Linear(current_dim, width))
            head.append(nn.BatchNorm1d(width))
            head.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                head.append(nn.Dropout(p=dropout_p))
            current_dim = width

        self.feature_dim = widths[-1]
        return nn.Sequential(*head)

    def _init_params(self):
        """Initialize Conv2d, BatchNorm and Linear parameters; other modules are untouched."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def featuremaps(self, x):
        """Run the convolutional backbone (conv1 through conv5) and return its output."""
        backbone = (
            self.conv1,
            self.maxpool,
            self.conv2,
            self.conv3,
            self.conv4,
            self.conv5,
        )
        for stage in backbone:
            x = stage(x)
        return x

    def forward(self, x, return_featuremaps=False):
        """Return feature maps, embeddings, or training outputs.

        With ``return_featuremaps`` the backbone output is returned. Otherwise
        the pooled (and, if present, ``fc``-projected) embedding is returned in
        eval mode. In training mode the classifier output is returned for the
        "softmax" loss and ``(classifier output, embedding)`` for "triplet".

        Raises:
            KeyError: In training mode when ``self.loss`` is neither value.
        """
        fmap = self.featuremaps(x)
        if return_featuremaps:
            return fmap

        embedding = self.global_avgpool(fmap)
        embedding = embedding.view(embedding.size(0), -1)
        if self.fc is not None:
            embedding = self.fc(embedding)
        if not self.training:
            return embedding

        logits = self.classifier(embedding)
        if self.loss == "softmax":
            return logits
        if self.loss == "triplet":
            return logits, embedding
        raise KeyError("Unsupported loss: {}".format(self.loss))
408
+
409
+
410
def init_pretrained_weights(model, key=""):
    """Initializes model with pretrained weights.

    The checkpoint is read from ``<torch home>/checkpoints/<key>_imagenet.pth``,
    where torch home is ``$TORCH_HOME`` or, failing that,
    ``$XDG_CACHE_HOME/torch`` (``~/.cache/torch`` by default). If the file is
    missing it is downloaded with ``gdown`` from ``pretrained_urls[key]``.

    Layers that don't match with pretrained layers in name or size are kept unchanged.

    Args:
        model: Module whose ``state_dict`` is updated in place.
        key: Model name used for the cache file name and the download URL.

    Raises:
        KeyError: If a download is needed and ``key`` is not in ``pretrained_urls``.
        ImportError: If a download is needed and ``gdown`` is not installed.
    """
    import os
    from collections import OrderedDict

    cache_root = os.getenv("XDG_CACHE_HOME", "~/.cache")
    torch_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(cache_root, "torch"))
    )
    model_dir = os.path.join(torch_home, "checkpoints")
    os.makedirs(model_dir, exist_ok=True)

    cached_file = os.path.join(model_dir, key + "_imagenet.pth")
    if not os.path.exists(cached_file):
        # Imported lazily so an already cached checkpoint does not require gdown.
        import gdown  # pylint: disable=import-error

        gdown.download(pretrained_urls[key], cached_file, quiet=False)

    # NOTE: torch.load unpickles the file, so the cache must only hold trusted
    # checkpoints. map_location="cpu" lets a checkpoint saved from a CUDA
    # device load on hosts without CUDA; load_state_dict below copies the
    # values into the model's own parameters.
    state_dict = torch.load(cached_file, map_location="cpu")
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for name, tensor in state_dict.items():
        if name.startswith("module."):
            name = name[7:]  # discard module.

        if name in model_dict and model_dict[name].size() == tensor.size():
            new_state_dict[name] = tensor
            matched_layers.append(name)
        else:
            discarded_layers.append(name)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            "please check the key names manually "
            "(** ignored and continue **)".format(cached_file)
        )
    else:
        print(
            'Successfully loaded imagenet pretrained weights from "{}"'.format(
                cached_file
            )
        )
        if discarded_layers:
            print(
                "** The following layers are discarded "
                "due to unmatched keys or layer size: {}".format(discarded_layers)
            )
485
+
486
+
487
+ ##########
488
+ # Instantiation
489
+ ##########
490
def osnet_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build the standard-size OSNet (width x1.0), optionally loading pretrained weights."""
    stage_widths = [64, 256, 384, 512]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x1_0")
    return net
503
+
504
+
505
+
506
+
507
def osnet_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build the medium-size OSNet (width x0.75), optionally loading pretrained weights."""
    stage_widths = [48, 192, 288, 384]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x0_75")
    return net
520
+
521
+
522
def osnet_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build the tiny OSNet (width x0.5), optionally loading pretrained weights."""
    stage_widths = [32, 128, 192, 256]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x0_5")
    return net
535
+
536
+
537
def osnet_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build the very tiny OSNet (width x0.25), optionally loading pretrained weights."""
    stage_widths = [16, 64, 96, 128]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x0_25")
    return net
550
+
551
+
552
def osnet_ibn_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build the standard-size OSNet (width x1.0) with ``IN=True``.

    Ref: Pan et al. Two at Once: Enhancing Learning and Generalization
    Capacities via IBN-Net. ECCV, 2018.
    """
    stage_widths = [64, 256, 384, 512]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        IN=True,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_ibn_x1_0")
    return net