supervisely 6.73.417__py3-none-any.whl → 6.73.419__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. supervisely/api/entity_annotation/figure_api.py +89 -45
  2. supervisely/nn/inference/inference.py +61 -45
  3. supervisely/nn/inference/instance_segmentation/instance_segmentation.py +1 -0
  4. supervisely/nn/inference/object_detection/object_detection.py +1 -0
  5. supervisely/nn/inference/session.py +4 -4
  6. supervisely/nn/model/model_api.py +31 -20
  7. supervisely/nn/model/prediction.py +11 -0
  8. supervisely/nn/model/prediction_session.py +33 -6
  9. supervisely/nn/tracker/__init__.py +1 -2
  10. supervisely/nn/tracker/base_tracker.py +44 -0
  11. supervisely/nn/tracker/botsort/__init__.py +1 -0
  12. supervisely/nn/tracker/botsort/botsort_config.yaml +31 -0
  13. supervisely/nn/tracker/botsort/osnet_reid/osnet.py +566 -0
  14. supervisely/nn/tracker/botsort/osnet_reid/osnet_reid_interface.py +88 -0
  15. supervisely/nn/tracker/botsort/tracker/__init__.py +0 -0
  16. supervisely/nn/tracker/{bot_sort → botsort/tracker}/basetrack.py +1 -2
  17. supervisely/nn/tracker/{utils → botsort/tracker}/gmc.py +51 -59
  18. supervisely/nn/tracker/{deep_sort/deep_sort → botsort/tracker}/kalman_filter.py +71 -33
  19. supervisely/nn/tracker/botsort/tracker/matching.py +202 -0
  20. supervisely/nn/tracker/{bot_sort/bot_sort.py → botsort/tracker/mc_bot_sort.py} +68 -81
  21. supervisely/nn/tracker/botsort_tracker.py +259 -0
  22. supervisely/project/project.py +1 -1
  23. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/METADATA +5 -3
  24. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/RECORD +29 -42
  25. supervisely/nn/tracker/bot_sort/__init__.py +0 -21
  26. supervisely/nn/tracker/bot_sort/fast_reid_interface.py +0 -152
  27. supervisely/nn/tracker/bot_sort/matching.py +0 -127
  28. supervisely/nn/tracker/bot_sort/sly_tracker.py +0 -401
  29. supervisely/nn/tracker/deep_sort/__init__.py +0 -6
  30. supervisely/nn/tracker/deep_sort/deep_sort/__init__.py +0 -1
  31. supervisely/nn/tracker/deep_sort/deep_sort/detection.py +0 -49
  32. supervisely/nn/tracker/deep_sort/deep_sort/iou_matching.py +0 -81
  33. supervisely/nn/tracker/deep_sort/deep_sort/linear_assignment.py +0 -202
  34. supervisely/nn/tracker/deep_sort/deep_sort/nn_matching.py +0 -176
  35. supervisely/nn/tracker/deep_sort/deep_sort/track.py +0 -166
  36. supervisely/nn/tracker/deep_sort/deep_sort/tracker.py +0 -145
  37. supervisely/nn/tracker/deep_sort/deep_sort.py +0 -301
  38. supervisely/nn/tracker/deep_sort/generate_clip_detections.py +0 -90
  39. supervisely/nn/tracker/deep_sort/preprocessing.py +0 -70
  40. supervisely/nn/tracker/deep_sort/sly_tracker.py +0 -273
  41. supervisely/nn/tracker/tracker.py +0 -285
  42. supervisely/nn/tracker/utils/kalman_filter.py +0 -492
  43. supervisely/nn/tracking/__init__.py +0 -1
  44. supervisely/nn/tracking/boxmot.py +0 -114
  45. supervisely/nn/tracking/tracking.py +0 -24
  46. /supervisely/nn/tracker/{utils → botsort/osnet_reid}/__init__.py +0 -0
  47. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/LICENSE +0 -0
  48. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/WHEEL +0 -0
  49. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/entry_points.txt +0 -0
  50. {supervisely-6.73.417.dist-info → supervisely-6.73.419.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,7 @@
1
1
  from supervisely.sly_logger import logger
2
2
 
3
3
  try:
4
- from supervisely.nn.tracker.bot_sort import BoTTracker
5
- from supervisely.nn.tracker.deep_sort import DeepSortTracker
4
+ from supervisely.nn.tracker.botsort_tracker import BotSortTracker
6
5
  except ImportError:
7
6
  logger.error(
8
7
  "Failed to import tracker modules. Please try install extras with 'pip install supervisely[tracking]'"
@@ -0,0 +1,44 @@
1
+ from typing import List, Dict, Any
2
+ import supervisely as sly
3
+ from supervisely import Annotation, VideoAnnotation
4
+ import numpy as np
5
+
6
class BaseTracker:
    """Base class for multi-object trackers.

    Resolves the compute device at construction time and declares the interface
    (``update``, ``reset``, ``track``, ``video_annotation``) that concrete
    trackers override.
    """

    def __init__(self, settings: dict = None, device: str = None):
        """Store the settings and resolve ``self.device``.

        :param settings: Tracker settings dict. A non-``None``
            ``settings["device"]`` takes precedence over ``device``.
        :param device: Device to use when ``settings`` does not specify one:
            ``"cpu"``, ``"cuda"``/``"cuda:N"``, ``"auto"`` or a ``torch.device``.
            ``None`` and ``"auto"`` select CUDA if available, otherwise CPU.
        :raises ValueError: if the resolved device is neither CPU nor CUDA.
        """
        import torch  # pylint: disable=import-error

        self.settings = settings or {}
        auto_device = "cuda" if torch.cuda.is_available() else "cpu"

        requested = self.settings.get("device")
        if requested is None:
            requested = device
        if requested is not None:
            # str() normalizes torch.device objects (e.g. "cuda:0") so that
            # _validate_device can rely on string methods.
            requested = str(requested)

        # "auto" is honoured whether it comes from settings or from the argument
        # (previously only the settings value was resolved, and an "auto"
        # argument was rejected by _validate_device).
        self.device = auto_device if requested in (None, "auto") else requested

        self._validate_device()

    def update(self, frame: np.ndarray, annotation: Annotation) -> List[Dict[str, Any]]:
        """Process one frame together with its detections.

        Must be implemented by subclasses; the format of the returned dicts is
        defined by the subclass.
        """
        raise NotImplementedError("This method should be overridden by subclasses.")

    def reset(self) -> None:
        """Reset tracker state."""
        pass

    def track(self, frames: List[np.ndarray], annotations: List[Annotation]) -> VideoAnnotation:
        """Track objects over ``frames`` and return a ``VideoAnnotation``.

        Must be implemented by subclasses. ``annotations`` presumably holds one
        annotation per frame, index-aligned with ``frames`` — confirm in the
        subclass.
        """
        raise NotImplementedError("This method should be overridden by subclasses.")

    @property
    def video_annotation(self) -> VideoAnnotation:
        """Return the accumulated VideoAnnotation."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    def _validate_device(self) -> None:
        """Raise ValueError unless ``self.device`` is ``"cpu"`` or starts with ``"cuda"``."""
        if self.device != 'cpu' and not self.device.startswith('cuda'):
            raise ValueError(
                f"Invalid device '{self.device}'. Supported devices are 'cpu' or 'cuda'."
            )
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,31 @@
1
+ # botsort_config.yaml
2
+
3
+ name: "BotSORT"
4
+ device: "auto" # "cuda" or "cpu", "auto" will use cuda if available
5
+ fp16: false
6
+
7
+ # BoTSORT tracking parameters
8
+ track_high_thresh: 0.6
9
+ track_low_thresh: 0.1
10
+ new_track_thresh: 0.7
11
+ track_buffer: 30
12
+ match_thresh: 0.8
13
+ min_box_area: 10.0
14
+
15
+ # Appearance model (ReID)
16
+ with_reid: true
17
+ reid_weights: null
18
+ proximity_thresh: 0.5
19
+ appearance_thresh: 0.25
20
+
21
+ # Algorithm flags
22
+ fuse_score: false
23
+ ablation: false
24
+ mot20: false
25
+
26
+ # Camera motion compensation
27
+ cmc_method: "sparseOptFlow"
28
+
29
+ # Performance
30
+ fps: 30
31
+
@@ -0,0 +1,566 @@
1
+
2
+ from __future__ import absolute_import, division
3
+
4
+ import warnings
5
+ from supervisely import logger
6
+
7
+ try:
8
+ import torch # pylint: disable=import-error
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ except ImportError:
12
+ logger.warning("torch is not installed, OSNet re-ID cannot be used.")
13
+
14
+
15
# Public factory functions exported by this module.
__all__ = ["osnet_x1_0", "osnet_x0_75", "osnet_x0_5", "osnet_x0_25", "osnet_ibn_x1_0"]

# Google Drive download URLs (fetched with ``gdown``) of the ImageNet-pretrained
# checkpoints, keyed by the same names as the factory functions. Looked up by
# ``init_pretrained_weights`` only when the checkpoint is not already cached.
pretrained_urls = {
    "osnet_x1_0": "https://drive.google.com/uc?id=1LaG1EJpHrxdAxKnSCJ_i0u-nbxSAeiFY",
    "osnet_x0_75": "https://drive.google.com/uc?id=1uwA9fElHOk3ZogwbeY5GkLI6QPTX70Hq",
    "osnet_x0_5": "https://drive.google.com/uc?id=16DGLbZukvVYgINws8u8deSaOqjybZ83i",
    "osnet_x0_25": "https://drive.google.com/uc?id=1rb8UN5ZzPKRc_xvtHlyDh-cSz88YX9hs",
    "osnet_ibn_x1_0": "https://drive.google.com/uc?id=1sr90V6irlYYDd4_4ISU2iruoRG8J__6l",
}
24
+
25
+
26
+ ##########
27
+ # Basic layers
28
+ ##########
29
class ConvLayer(nn.Module):
    """Bias-free Conv2d -> normalization -> in-place ReLU.

    Normalization is ``InstanceNorm2d(affine=True)`` when ``IN`` is true and
    ``BatchNorm2d`` otherwise. The submodule names ``conv``/``bn``/``relu`` form
    the state-dict keys matched against pretrained checkpoints.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        IN=False,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            groups=groups,
        )
        self.bn = (
            nn.InstanceNorm2d(out_channels, affine=True)
            if IN
            else nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
63
+
64
+
65
class Conv1x1(nn.Module):
    """Pointwise (1x1) bias-free convolution followed by BatchNorm and in-place ReLU."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=False,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
87
+
88
+
89
class Conv1x1Linear(nn.Module):
    """Pointwise (1x1) bias-free convolution followed by BatchNorm, with no activation."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
103
+
104
+
105
class Conv3x3(nn.Module):
    """3x3 bias-free convolution (padding 1) followed by BatchNorm and in-place ReLU."""

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            groups=groups,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
127
+
128
+
129
class LightConv3x3(nn.Module):
    """Lightweight 3x3 convolution.

    A linear pointwise (1x1) projection is followed by a depthwise 3x3
    convolution (``groups == out_channels``), then BatchNorm and in-place ReLU.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projection; no activation between it and the depthwise conv.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
        )
        # One 3x3 filter per channel (depthwise).
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=out_channels,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv2(self.conv1(x))))
158
+
159
+
160
+ ##########
161
+ # Building blocks for omni-scale feature learning
162
+ ##########
163
class ChannelGate(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Global average pooling feeds a two-layer 1x1-conv bottleneck (reduced by
    ``reduction``), optionally layer-normalized after the first layer. The
    result passes through ``gate_activation`` (``"sigmoid"``, ``"relu"`` or
    ``"linear"`` = none). Returns the gates themselves when ``return_gates`` is
    true, otherwise the input scaled by them.
    """

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation="sigmoid",
        reduction=16,
        layer_norm=False,
    ):
        super().__init__()
        num_gates = in_channels if num_gates is None else num_gates
        hidden = in_channels // reduction

        self.return_gates = return_gates
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, hidden, kernel_size=1, bias=True, padding=0)
        self.norm1 = nn.LayerNorm((hidden, 1, 1)) if layer_norm else None
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, num_gates, kernel_size=1, bias=True, padding=0)

        if gate_activation == "linear":
            self.gate_activation = None
        elif gate_activation == "relu":
            self.gate_activation = nn.ReLU(inplace=True)
        elif gate_activation == "sigmoid":
            self.gate_activation = nn.Sigmoid()
        else:
            raise RuntimeError("Unknown gate activation: {}".format(gate_activation))

    def forward(self, x):
        gates = self.fc1(self.global_avgpool(x))
        if self.norm1 is not None:
            gates = self.norm1(gates)
        gates = self.fc2(self.relu(gates))
        if self.gate_activation is not None:
            gates = self.gate_activation(gates)
        return gates if self.return_gates else x * gates
212
+
213
+
214
class OSBlock(nn.Module):
    """Omni-scale feature learning block.

    A 1x1 conv squeezes the input to ``out_channels // bottleneck_reduction``
    channels. Four streams of 1, 2, 3 and 4 stacked ``LightConv3x3`` layers
    process it, and each stream is reweighted by one shared ``ChannelGate``
    before the streams are summed. A linear 1x1 conv expands back to
    ``out_channels`` and the result is added to the input. The input is first
    passed through a linear 1x1 projection when the channel counts differ.
    Optional instance norm (``IN``) and a final ReLU follow.
    Extra ``kwargs`` are accepted and ignored.
    """

    def __init__(
        self, in_channels, out_channels, IN=False, bottleneck_reduction=4, **kwargs
    ):
        super().__init__()
        mid = out_channels // bottleneck_reduction

        def light_stack(depth):
            # Sequential of `depth` LightConv3x3(mid, mid) layers, built in order.
            return nn.Sequential(*[LightConv3x3(mid, mid) for _ in range(depth)])

        self.conv1 = Conv1x1(in_channels, mid)
        self.conv2a = LightConv3x3(mid, mid)
        self.conv2b = light_stack(2)
        self.conv2c = light_stack(3)
        self.conv2d = light_stack(4)
        self.gate = ChannelGate(mid)
        self.conv3 = Conv1x1Linear(mid, out_channels)
        self.downsample = (
            Conv1x1Linear(in_channels, out_channels)
            if in_channels != out_channels
            else None
        )
        self.IN = nn.InstanceNorm2d(out_channels, affine=True) if IN else None

    def forward(self, x):
        squeezed = self.conv1(x)
        branches = (
            self.conv2a(squeezed),
            self.conv2b(squeezed),
            self.conv2c(squeezed),
            self.conv2d(squeezed),
        )
        fused = self.gate(branches[0])
        for branch in branches[1:]:
            fused = fused + self.gate(branch)
        expanded = self.conv3(fused)
        residual = x if self.downsample is None else self.downsample(x)
        out = expanded + residual
        if self.IN is not None:
            out = self.IN(out)
        return F.relu(out)
263
+
264
+
265
+ ##########
266
+ # Network architecture
267
+ ##########
268
class OSNet(nn.Module):
    """Omni-Scale Network.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. TPAMI, 2021.

    In eval mode ``forward`` returns the embedding vector. In training mode it
    returns classifier logits (``loss="softmax"``) or ``(logits, embedding)``
    (``loss="triplet"``).

    NOTE: submodule attribute names form the state-dict keys that pretrained
    checkpoints are matched against; keep them stable.
    """

    def __init__(
        self,
        num_classes,
        blocks,
        layers,
        channels,
        feature_dim=512,
        loss="softmax",
        IN=False,
        **kwargs,
    ):
        """
        :param num_classes: output size of the identity classifier.
        :param blocks: block class for each of the three stages (e.g. ``OSBlock``).
        :param layers: number of blocks in each stage.
        :param channels: stem output width followed by each stage's output width
            (``len(blocks) + 1`` entries).
        :param feature_dim: embedding size. An int gives one FC layer, a sequence
            of ints gives stacked FC layers, and ``None``, a negative int or an
            empty sequence uses the backbone output directly.
        :param loss: ``"softmax"`` or ``"triplet"``; selects the training-mode output.
        :param IN: use instance norm in the stem and the first stage.
        :param kwargs: ignored; accepted for factory-function compatibility.
        """
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers), "blocks and layers must have the same length"
        assert num_blocks == len(channels) - 1, "channels must have len(blocks) + 1 entries"
        self.loss = loss
        self.feature_dim = feature_dim

        # convolutional backbone
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.conv2 = self._make_layer(
            blocks[0],
            layers[0],
            channels[0],
            channels[1],
            reduce_spatial_size=True,
            IN=IN,
        )
        self.conv3 = self._make_layer(
            blocks[1], layers[1], channels[1], channels[2], reduce_spatial_size=True
        )
        self.conv4 = self._make_layer(
            blocks[2], layers[2], channels[2], channels[3], reduce_spatial_size=False
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # fully connected layer (also updates self.feature_dim to the final width)
        self.fc = self._construct_fc_layer(
            self.feature_dim, channels[3], dropout_p=None
        )
        # identity classification layer
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _make_layer(
        self, block, layer, in_channels, out_channels, reduce_spatial_size, IN=False
    ):
        """Build one stage: ``layer`` blocks, then optionally a 1x1 conv + 2x2 avg-pool."""
        layers = [block(in_channels, out_channels, IN=IN)]
        for _ in range(1, layer):
            layers.append(block(out_channels, out_channels, IN=IN))

        if reduce_spatial_size:
            layers.append(
                nn.Sequential(
                    Conv1x1(out_channels, out_channels), nn.AvgPool2d(2, stride=2)
                )
            )

        return nn.Sequential(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Build the Linear-BN-ReLU(-Dropout) embedding head.

        :param fc_dims: int or sequence of ints (one hidden width per FC layer).
            ``None``, a negative int or an empty sequence disables the head.
        :param input_dim: width of the pooled backbone features.
        :param dropout_p: dropout probability after each layer, or ``None``.
        :return: ``nn.Sequential`` head, or ``None`` when disabled. Side effect:
            sets ``self.feature_dim`` to the resulting embedding width.
        """
        # Normalize to a list first: comparing a sequence with ``< 0`` (as the
        # old check did) raised TypeError for the sequence form.
        if fc_dims is None:
            fc_dims = []
        elif isinstance(fc_dims, int):
            fc_dims = [] if fc_dims < 0 else [fc_dims]
        else:
            fc_dims = list(fc_dims)

        if not fc_dims:
            self.feature_dim = input_dim
            return None

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        """Kaiming-init convs, unit/zero-init norms, N(0, 0.01)-init linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, x):
        """Run the convolutional backbone and return the final feature maps."""
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return x

    def forward(self, x, return_featuremaps=False):
        """Return feature maps, the embedding (eval mode) or training outputs.

        :raises KeyError: in training mode when ``self.loss`` is unsupported.
        """
        x = self.featuremaps(x)
        if return_featuremaps:
            return x
        v = self.global_avgpool(x)
        v = v.view(v.size(0), -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        y = self.classifier(v)
        if self.loss == "softmax":
            return y
        elif self.loss == "triplet":
            return y, v
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
408
+
409
+
410
def init_pretrained_weights(model, key=""):
    """Initializes model with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept unchanged.

    The checkpoint ``<key>_imagenet.pth`` is looked up in
    ``<torch home>/checkpoints``. Torch home is ``$TORCH_HOME``, falling back to
    ``$XDG_CACHE_HOME/torch`` (default ``~/.cache/torch``). When the checkpoint
    is missing it is downloaded from ``pretrained_urls[key]`` with ``gdown``.
    A ``module.`` prefix on checkpoint keys is stripped before matching.

    :param model: the network to initialize in place.
    :param key: pretrained variant name; must be a key of ``pretrained_urls``
        unless the checkpoint is already cached.
    :raises KeyError: if ``key`` is unknown and the checkpoint is not cached.
    :raises FileNotFoundError: if the download did not produce the checkpoint file.
    """
    import os
    from collections import OrderedDict

    import gdown  # pylint: disable=import-error

    def _get_torch_home():
        ENV_TORCH_HOME = "TORCH_HOME"
        ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
        DEFAULT_CACHE_DIR = "~/.cache"
        torch_home = os.path.expanduser(
            os.getenv(
                ENV_TORCH_HOME,
                os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
            )
        )
        return torch_home

    torch_home = _get_torch_home()
    model_dir = os.path.join(torch_home, "checkpoints")
    os.makedirs(model_dir, exist_ok=True)
    filename = key + "_imagenet.pth"
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        gdown.download(pretrained_urls[key], cached_file, quiet=False)
        if not os.path.exists(cached_file):
            # Fail with a clear message instead of an opaque error from torch.load.
            raise FileNotFoundError(
                'Failed to download pretrained weights "{}" to "{}"'.format(
                    key, cached_file
                )
            )

    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # CPU-only hosts; load_state_dict below copies values into the model's
    # own parameters regardless of their device.
    state_dict = torch.load(cached_file, map_location="cpu")
    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k[7:]  # discard "module." (checkpoint saved from a wrapped model)

        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Merge into the model's current state so unmatched layers keep their values.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if len(matched_layers) == 0:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            "please check the key names manually "
            "(** ignored and continue **)".format(cached_file)
        )
    else:
        print(
            'Successfully loaded imagenet pretrained weights from "{}"'.format(
                cached_file
            )
        )
        if len(discarded_layers) > 0:
            print(
                "** The following layers are discarded "
                "due to unmatched keys or layer size: {}".format(discarded_layers)
            )
485
+
486
+
487
+ ##########
488
+ # Instantiation
489
+ ##########
490
def osnet_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build OSNet x1.0 (standard width, channels 64/256/384/512).

    :param num_classes: output size of the identity classifier.
    :param pretrained: when true, load the ``"osnet_x1_0"`` ImageNet weights via
        ``init_pretrained_weights`` (downloaded if not cached).
    :param loss: ``"softmax"`` or ``"triplet"``; selects the training-mode output.
    :param kwargs: forwarded to ``OSNet`` (e.g. ``feature_dim``, ``IN``).
    :return: the constructed ``OSNet``.
    """
    stage_widths = [64, 256, 384, 512]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x1_0")
    return net
503
+
504
+
505
+
506
+
507
def osnet_x0_75(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build OSNet x0.75 (medium width, channels 48/192/288/384).

    :param num_classes: output size of the identity classifier.
    :param pretrained: when true, load the ``"osnet_x0_75"`` ImageNet weights via
        ``init_pretrained_weights`` (downloaded if not cached).
    :param loss: ``"softmax"`` or ``"triplet"``; selects the training-mode output.
    :param kwargs: forwarded to ``OSNet`` (e.g. ``feature_dim``, ``IN``).
    :return: the constructed ``OSNet``.
    """
    stage_widths = [48, 192, 288, 384]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x0_75")
    return net
520
+
521
+
522
def osnet_x0_5(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build OSNet x0.5 (tiny width, channels 32/128/192/256).

    :param num_classes: output size of the identity classifier.
    :param pretrained: when true, load the ``"osnet_x0_5"`` ImageNet weights via
        ``init_pretrained_weights`` (downloaded if not cached).
    :param loss: ``"softmax"`` or ``"triplet"``; selects the training-mode output.
    :param kwargs: forwarded to ``OSNet`` (e.g. ``feature_dim``, ``IN``).
    :return: the constructed ``OSNet``.
    """
    stage_widths = [32, 128, 192, 256]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x0_5")
    return net
535
+
536
+
537
def osnet_x0_25(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build OSNet x0.25 (very tiny width, channels 16/64/96/128).

    :param num_classes: output size of the identity classifier.
    :param pretrained: when true, load the ``"osnet_x0_25"`` ImageNet weights via
        ``init_pretrained_weights`` (downloaded if not cached).
    :param loss: ``"softmax"`` or ``"triplet"``; selects the training-mode output.
    :param kwargs: forwarded to ``OSNet`` (e.g. ``feature_dim``, ``IN``).
    :return: the constructed ``OSNet``.
    """
    stage_widths = [16, 64, 96, 128]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_x0_25")
    return net
550
+
551
+
552
def osnet_ibn_x1_0(num_classes=1000, pretrained=True, loss="softmax", **kwargs):
    """Build OSNet-IBN x1.0: standard width (64/256/384/512) with ``IN=True``.

    Ref: Pan et al. Two at Once: Enhancing Learning and Generalization
    Capacities via IBN-Net. ECCV, 2018.

    :param num_classes: output size of the identity classifier.
    :param pretrained: when true, load the ``"osnet_ibn_x1_0"`` ImageNet weights
        via ``init_pretrained_weights`` (downloaded if not cached).
    :param loss: ``"softmax"`` or ``"triplet"``; selects the training-mode output.
    :param kwargs: forwarded to ``OSNet`` (must not contain ``IN``).
    :return: the constructed ``OSNet``.
    """
    stage_widths = [64, 256, 384, 512]
    net = OSNet(
        num_classes,
        blocks=[OSBlock] * 3,
        layers=[2] * 3,
        channels=stage_widths,
        loss=loss,
        IN=True,
        **kwargs,
    )
    if pretrained:
        init_pretrained_weights(net, key="osnet_ibn_x1_0")
    return net